initiate
- app.py +194 -0
- optvq/data/__pycache__/dataloader.cpython-310.pyc +0 -0
- optvq/data/__pycache__/dataset.cpython-310.pyc +0 -0
- optvq/data/__pycache__/preprocessor.cpython-310.pyc +0 -0
- optvq/data/dataloader.py +50 -0
- optvq/data/dataset.py +62 -0
- optvq/data/preprocessor.py +71 -0
- optvq/losses/__pycache__/aeloss_disc.cpython-310.pyc +0 -0
- optvq/losses/aeloss.py +76 -0
- optvq/losses/aeloss_disc.py +177 -0
- optvq/models/__pycache__/discriminator.cpython-310.pyc +0 -0
- optvq/models/__pycache__/quantizer.cpython-310.pyc +0 -0
- optvq/models/__pycache__/vqgan.cpython-310.pyc +0 -0
- optvq/models/__pycache__/vqgan_hf.cpython-310.pyc +0 -0
- optvq/models/backbone/__pycache__/diffusion.cpython-310.pyc +0 -0
- optvq/models/backbone/diffusion.py +787 -0
- optvq/models/backbone/simple_cnn.py +90 -0
- optvq/models/discriminator.py +155 -0
- optvq/models/quantizer.py +260 -0
- optvq/models/vqgan.py +168 -0
- optvq/models/vqgan_hf.py +69 -0
- optvq/trainer/__pycache__/arguments.cpython-310.pyc +0 -0
- optvq/trainer/__pycache__/pipeline.cpython-310.pyc +0 -0
- optvq/trainer/arguments.py +53 -0
- optvq/trainer/pipeline.py +362 -0
- optvq/utils/__pycache__/func.cpython-310.pyc +0 -0
- optvq/utils/__pycache__/init.cpython-310.pyc +0 -0
- optvq/utils/__pycache__/logger.cpython-310.pyc +0 -0
- optvq/utils/__pycache__/metrics.cpython-310.pyc +0 -0
- optvq/utils/func.py +27 -0
- optvq/utils/init.py +36 -0
- optvq/utils/logger.py +386 -0
- optvq/utils/metrics.py +60 -0
- requirements.txt +25 -0
app.py
ADDED
@@ -0,0 +1,194 @@
import gradio as gr
from io import BytesIO
import os
import sys

import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from PIL import Image
from omegaconf import OmegaConf

import torch
from torchvision import transforms as T

from optvq.models.quantizer import sinkhorn
from optvq.utils.init import seed_everything
seed_everything(42)
from optvq.models.vqgan_hf import VQModelHF
matplotlib.rcParams['font.family'] = 'Times New Roman'

#################
N_data = 50
N_code = 20
dim = 2
handler = None
device = torch.device("cpu")
#################

def nearest(src, trg):
    dis_mat = torch.cdist(src, trg)
    min_idx = torch.argmin(dis_mat, dim=-1)
    return min_idx

def normalize(A, dim, mode="all"):
    if mode == "all":
        A = (A - A.mean()) / (A.std() + 1e-6)
        A = A - A.min()
    elif mode == "dim":
        A = A / dim
    elif mode == "null":
        pass
    return A

def draw_NN(data, code):
    # nearest neighbor method
    indices = nearest(data, code)
    data = data.numpy()
    code = code.numpy()

    plt.figure(figsize=(3, 2.5), dpi=400)
    # draw matching arrows in red, alpha=0.6
    for i in range(data.shape[0]):
        idx = indices[i].item()
        start = data[i]
        end = code[idx]
        plt.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
                  head_width=0.05, head_length=0.05, fc='red', ec='red', alpha=0.6,
                  ls="-", lw=0.5)
    plt.scatter(data[:, 0], data[:, 1], s=10, marker="o", c="gray", label="Data")
    plt.scatter(code[:, 0], code[:, 1], s=25, marker="*", c="blue", label="Code")
    plt.legend(loc="lower right")
    plt.grid(color="gray", alpha=0.8, ls="-.", lw=0.5)
    plt.title("Nearest neighbor")

    buf = BytesIO()
    plt.savefig(buf, format="png")
    buf.seek(0)
    image = Image.open(buf)
    return image

def draw_optvq(data, code):
    cost = torch.cdist(data, code, p=2.0)
    cost = normalize(cost, dim, mode="all")
    Q = sinkhorn(cost, n_iters=5, epsilon=10, is_distributed=False)
    indices = torch.argmax(Q, dim=-1)
    data = data.numpy()
    code = code.numpy()

    plt.figure(figsize=(3, 2.5), dpi=400)
    # draw matching arrows in green, alpha=0.6
    for i in range(data.shape[0]):
        idx = indices[i].item()
        start = data[i]
        end = code[idx]
        plt.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
                  head_width=0.05, head_length=0.05, fc='green', ec='green', alpha=0.6,
                  ls="-", lw=0.5)
    plt.scatter(data[:, 0], data[:, 1], s=10, marker="o", c="gray", label="Data")
    plt.scatter(code[:, 0], code[:, 1], s=25, marker="*", c="blue", label="Code")
    plt.legend(loc="lower right")
    plt.grid(color="gray", alpha=0.8, ls="-.", lw=0.5)
    plt.title("Optimal Transport (OptVQ)")

    buf = BytesIO()
    plt.savefig(buf, format="png")
    buf.seek(0)
    image = Image.open(buf)
    return image

def draw_process(x, y, std):
    data = torch.randn(N_data, dim)
    code = torch.randn(N_code, dim) * std
    code[:, 0] += x
    code[:, 1] += y

    image_NN = draw_NN(data, code)
    image_optvq = draw_optvq(data, code)

    return image_NN, image_optvq

class Handler:
    def __init__(self, device):
        self.transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(256),
            T.ToTensor()
        ])
        self.device = device

        self.basevq = VQModelHF.from_pretrained("BorelTHU/basevq-16x16x4")
        self.basevq.to(self.device)
        self.basevq.eval()

        self.vqgan = VQModelHF.from_pretrained("BorelTHU/vqgan-16x16")
        self.vqgan.to(self.device)
        self.vqgan.eval()

        self.optvq = VQModelHF.from_pretrained("BorelTHU/optvq-16x16x4")
        self.optvq.to(self.device)
        self.optvq.eval()

    def tensor_to_image(self, tensor):
        img = tensor.squeeze(0).cpu().permute(1, 2, 0).numpy()
        img = (img + 1) / 2 * 255
        img = img.astype("uint8")
        return img

    def process_image(self, img: np.ndarray):
        img = Image.fromarray(img.astype("uint8"))
        img = self.transform(img)
        img = img.unsqueeze(0).to(self.device)
        with torch.no_grad():
            img = 2 * img - 1
            # basevq
            quant, *_ = self.basevq.encode(img)
            basevq_rec = self.basevq.decode(quant)
            # vqgan
            quant, *_ = self.vqgan.encode(img)
            vqgan_rec = self.vqgan.decode(quant)
            # optvq
            quant, *_ = self.optvq.encode(img)
            optvq_rec = self.optvq.decode(quant)

        # tensor to PIL image
        img = self.tensor_to_image(img)
        basevq_rec = self.tensor_to_image(basevq_rec)
        vqgan_rec = self.tensor_to_image(vqgan_rec)
        optvq_rec = self.tensor_to_image(optvq_rec)
        return img, basevq_rec, vqgan_rec, optvq_rec

if __name__ == "__main__":
    # create the model handler
    handler = Handler(device=device)

    # create the interface
    with gr.Blocks() as demo:
        gr.Textbox(value="This demo shows the image reconstruction comparison between OptVQ and other methods. The input image is resized to 256 x 256 and then fed into the models. The output images are the reconstructed images from the latent codes.", label="Demo 1: Image reconstruction results")
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input data", image_mode="RGB", type="numpy")
                btn_demo1 = gr.Button(value="Run reconstruction")
            image_basevq = gr.Image(label="BaseVQ rec.")
            image_vqgan = gr.Image(label="VQGAN rec.")
            image_optvq = gr.Image(label="OptVQ rec.")
        btn_demo1.click(fn=handler.process_image, inputs=[image_input], outputs=[image_input, image_basevq, image_vqgan, image_optvq])

        gr.Textbox(value="This demo shows the 2D visualizations of nearest neighbor and optimal transport (OptVQ) methods. The data points are randomly generated from a normal distribution, and the matching results are shown as arrows with different colors.", label="Demo 2: 2D visualizations of matching results")
        with gr.Row():
            with gr.Column():
                input_x = gr.Slider(label="x", value=0, minimum=-10, maximum=10, step=0.1)
                input_y = gr.Slider(label="y", value=0, minimum=-10, maximum=10, step=0.1)
                input_std = gr.Slider(label="std", value=1, minimum=0, maximum=5, step=0.1)
                btn_demo2 = gr.Button(value="Run 2D example")
            output_nn = gr.Image(label="NN")
            output_optvq = gr.Image(label="OptVQ")

        # set the function
        input_x.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
        input_y.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
        input_std.change(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])
        btn_demo2.click(fn=draw_process, inputs=[input_x, input_y, input_std], outputs=[output_nn, output_optvq])

    demo.launch()
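The 2D demo above contrasts plain nearest-neighbour matching with the transport plan returned by sinkhorn from optvq.models.quantizer. As a rough reference only, the snippet below is a minimal, self-contained Sinkhorn-Knopp sketch with uniform marginals; the helper name, scaling, and marginal handling are assumptions and may differ from the repository's actual implementation.

# Hypothetical standalone sketch of an entropic-OT assignment (Sinkhorn-Knopp);
# the real optvq.models.quantizer.sinkhorn may normalise differently.
import torch

def sinkhorn_sketch(cost: torch.Tensor, n_iters: int = 5, epsilon: float = 10.0) -> torch.Tensor:
    # cost: (N_data, N_code) pairwise distances
    Q = torch.exp(-cost / epsilon)              # Gibbs kernel
    Q = Q / Q.sum()                             # start from a joint distribution
    n, m = Q.shape
    for _ in range(n_iters):
        Q = Q / Q.sum(dim=1, keepdim=True) / n  # rows -> uniform mass 1/n
        Q = Q / Q.sum(dim=0, keepdim=True) / m  # cols -> uniform mass 1/m
    return Q * n                                # rows sum to ~1, so argmax gives a soft assignment

data = torch.randn(50, 2)
code = torch.randn(20, 2)
Q = sinkhorn_sketch(torch.cdist(data, code))
indices = torch.argmax(Q, dim=-1)               # transport-based code index per data point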
optvq/data/__pycache__/dataloader.cpython-310.pyc
ADDED
Binary file (1.63 kB)
optvq/data/__pycache__/dataset.cpython-310.pyc
ADDED
Binary file (2.22 kB)
optvq/data/__pycache__/preprocessor.cpython-310.pyc
ADDED
Binary file (1.71 kB)
optvq/data/dataloader.py
ADDED
@@ -0,0 +1,50 @@
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------

from typing import Union

import torch
from torch.utils.data import Subset

def maybe_get_subset(dataset, subset_size: Union[int, float] = None, num_data_repeat: int = None):
    """
    num_data_repeat repeats the sampled indices so that small subsets do not
    incur excessive per-epoch loading overhead.
    """
    if subset_size is None:
        return dataset
    else:
        if subset_size < 1.0:
            subset_size = len(dataset) * subset_size
        subset_size = int(subset_size)
        selected_indices = torch.randperm(len(dataset))[:subset_size]
        if num_data_repeat is not None:
            selected_indices = selected_indices.repeat(num_data_repeat)
        return Subset(dataset, selected_indices)

class LoaderWrapper:
    """
    Wraps a dataloader so that one pass yields a fixed number of iterations,
    restarting the underlying loader whenever it is exhausted.
    """
    def __init__(self, loader, total_iterations: int = None):
        self.loader = loader
        self.total_iterations = total_iterations if total_iterations is not None else len(loader)

    def __iter__(self):
        self.generator = iter(self.loader)
        self.counter = 0
        return self

    def __next__(self):
        if self.counter >= self.total_iterations:
            self.counter = 0
            raise StopIteration
        else:
            self.counter += 1
            try:
                return next(self.generator)
            except StopIteration:
                self.generator = iter(self.loader)
                return next(self.generator)
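A minimal usage sketch (assumed, not part of the commit) of the two helpers above: maybe_get_subset draws a random subset, and LoaderWrapper fixes the number of iterations per pass by restarting the underlying DataLoader.

import torch
from torch.utils.data import DataLoader, TensorDataset
from optvq.data.dataloader import LoaderWrapper, maybe_get_subset

dataset = TensorDataset(torch.randn(8, 3), torch.zeros(8))
dataset = maybe_get_subset(dataset, subset_size=0.5)       # keep a random half (4 samples)
loader = LoaderWrapper(DataLoader(dataset, batch_size=2), total_iterations=10)

for step, (x, y) in enumerate(loader):                     # yields exactly 10 batches, cycling the data
    pass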
optvq/data/dataset.py
ADDED
@@ -0,0 +1,62 @@
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------

import os
from PIL import Image
import numpy as np

import torch
from torch.utils.data import Dataset
from torchvision import transforms

from .preprocessor import normalize_params

class ImageNetDataset(Dataset):
    def __init__(self, root, transform=None, convert_to_numpy: bool = True, post_normalize: str = "plain"):
        self.root = root
        self.transform = transform
        self.convert_to_numpy = convert_to_numpy
        self.post_normalize = transforms.Normalize(
            **normalize_params[post_normalize]
        )

        # find classes
        classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}

        # make dataset
        self.samples = []
        self.extensions = []
        for target_class in sorted(class_to_idx.keys()):
            class_idx = class_to_idx[target_class]
            target_dir = os.path.join(root, target_class)
            if not os.path.isdir(target_dir):
                continue
            for fname in sorted(os.listdir(target_dir)):
                path = os.path.join(target_dir, fname)
                item = (path, class_idx)
                self.samples.append(item)
                ext = path.split(".")[-1]
                if ext not in self.extensions:
                    self.extensions.append(ext)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, label = self.samples[index]
        image = Image.open(path)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        if self.convert_to_numpy:
            image = np.array(image).astype("uint8")
        # image augmentation
        image = self.transform(image=image)["image"]
        # to tensor and normalize
        image = (image / 255).astype(np.float32)
        image = torch.from_numpy(image).permute(2, 0, 1)
        image = self.post_normalize(image)
        return image, label
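Illustrative construction of the dataset above, assuming an albumentations-style transform such as the one defined in preprocessor.py (the dataset calls transform(image=...)["image"]); the root path is a placeholder and must point to an ImageNet-style class-folder tree.

from optvq.data.dataset import ImageNetDataset
from optvq.data.preprocessor import imagenet_preprocessor

train_set = ImageNetDataset(
    root="/path/to/imagenet/train",                       # placeholder: <class_name>/<image> layout
    transform=imagenet_preprocessor(resize=256, is_train=True),
    post_normalize="plain",                               # maps pixels to roughly [-1, 1]
)
image, label = train_set[0]                               # image: float tensor of shape (3, 256, 256)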
optvq/data/preprocessor.py
ADDED
@@ -0,0 +1,71 @@
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------

from typing import Optional
import numpy as np

from torchvision import transforms
import albumentations as A

BICUBIC = transforms.InterpolationMode.BICUBIC

normalize_params = {
    "plain": {"mean": (0.5,), "std": (0.5,)},
    "cnn": {"mean": (0.485, 0.456, 0.406), "std": (0.229, 0.224, 0.225)},
    "clip": {"mean": (0.48145466, 0.4578275, 0.40821073), "std": (0.26862954, 0.26130258, 0.27577711)}
}

recover_map_dict = {
    "plain": transforms.Normalize(
        mean=(-1,), std=(2,)
    ),
    "cnn": transforms.Normalize(
        mean=(-0.485/0.229, -0.456/0.224, -0.406/0.225),
        std=(1/0.229, 1/0.224, 1/0.225)
    ),
    "clip": transforms.Normalize(
        mean=(-0.48145466/0.26862954, -0.4578275/0.26130258, -0.40821073/0.27577711),
        std=(1/0.26862954, 1/0.26130258, 1/0.27577711)
    )
}

def get_recover_map(name: str):
    return recover_map_dict[name]

###########################################
# Preprocessor
###########################################

def plain_preprocessor(resize: Optional[int] = 32):
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
        transforms.Resize(resize),
    ])

def imagenet_preprocessor(resize: Optional[int] = 256, is_train: bool = True):
    if is_train:
        # augmentation v1
        # transform = A.Compose([
        #     A.SmallestMaxSize(max_size=resize),
        #     A.RandomCrop(height=resize, width=resize),
        #     A.HorizontalFlip(p=0.5),
        # ])

        # augmentation v2
        transform = A.Compose([
            A.SmallestMaxSize(max_size=resize),
            A.RandomResizedCrop(width=resize, height=resize, scale=(0.2, 1.0)),
            A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8),
            A.GaussianBlur(blur_limit=7, p=0.5),
            A.HorizontalFlip(p=0.5),
        ])
    else:
        transform = A.Compose([
            A.SmallestMaxSize(max_size=resize),
            A.CenterCrop(height=resize, width=resize),
        ])
    return transform
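A quick sanity check (illustrative, not in the commit) that each recover_map entry inverts the corresponding normalize_params entry, since ((x - m)/s - (-m/s)) / (1/s) = x.

import torch
from torchvision import transforms
from optvq.data.preprocessor import normalize_params, get_recover_map

x = torch.rand(3, 8, 8)
norm = transforms.Normalize(**normalize_params["cnn"])    # (x - mean) / std
recover = get_recover_map("cnn")                          # undoes the normalization
assert torch.allclose(recover(norm(x)), x, atol=1e-5)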
optvq/losses/__pycache__/aeloss_disc.cpython-310.pyc
ADDED
Binary file (4.27 kB)
optvq/losses/aeloss.py
ADDED
@@ -0,0 +1,76 @@
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------
# Modified from [thuanz123/enhancing-transformers](https://github.com/thuanz123/enhancing-transformers)
# Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
# ------------------------------------------------------------------------------
# Modified from [CompVis/taming-transformers](https://github.com/CompVis/taming-transformers)
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
# ------------------------------------------------------------------------------

from typing import Tuple

import lpips

import torch
import torch.nn as nn

class AELoss(nn.Module):
    """
    Args:
        loss_q_weight (float): weight for quantization loss
        loss_l1_weight (float): weight for L1 loss (loglaplace)
        loss_l2_weight (float): weight for L2 loss (loggaussian)
        loss_p_weight (float): weight for perceptual loss
    """
    def __init__(self, loss_q_weight: float = 1.0,
                 loss_l1_weight: float = 1.0,
                 loss_l2_weight: float = 1.0,
                 loss_p_weight: float = 1.0) -> None:
        super().__init__()
        self.loss_type = ["aeloss"]

        self.perceptual_loss = lpips.LPIPS(net="vgg", verbose=False)
        # freeze the perceptual loss
        for param in self.perceptual_loss.parameters():
            param.requires_grad = False

        self.loss_q_weight = loss_q_weight
        self.loss_l1_weight = loss_l1_weight
        self.loss_l2_weight = loss_l2_weight
        self.loss_p_weight = loss_p_weight

    @torch.autocast(device_type="cuda", enabled=False)
    def forward(self, q_loss: torch.FloatTensor,
                x: torch.FloatTensor,
                x_rec: torch.FloatTensor, *args, **kwargs) -> Tuple:
        x = x.float()
        x_rec = x_rec.float()

        # compute l1 loss
        loss_l1 = (x_rec - x).abs().mean() if self.loss_l1_weight > 0.0 else torch.tensor(0.0, device=x.device)

        # compute l2 loss
        loss_l2 = (x_rec - x).pow(2).mean() if self.loss_l2_weight > 0.0 else torch.tensor(0.0, device=x.device)

        # compute perceptual loss
        loss_p = self.perceptual_loss(x, x_rec).mean() if self.loss_p_weight > 0.0 else torch.tensor(0.0, device=x.device)

        # compute total loss
        loss = self.loss_p_weight * loss_p + \
               self.loss_l1_weight * loss_l1 + \
               self.loss_l2_weight * loss_l2
        loss += self.loss_q_weight * q_loss

        # get the log
        log = {
            "loss": loss.detach(),
            "loss_p": loss_p.detach(),
            "loss_l1": loss_l1.detach(),
            "loss_l2": loss_l2.detach(),
            "loss_q": q_loss.detach()
        }

        return loss, log
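A hedged sketch of how AELoss is presumably called in a training step; the tensors below stand in for a model's reconstruction and codebook loss, and lpips downloads VGG weights on first use.

import torch
from optvq.losses.aeloss import AELoss

criterion = AELoss(loss_q_weight=1.0, loss_l1_weight=1.0, loss_l2_weight=1.0, loss_p_weight=1.0)

x = torch.rand(2, 3, 64, 64) * 2 - 1                            # images in [-1, 1]
x_rec = (x + 0.05 * torch.randn_like(x)).requires_grad_()       # stand-in for the model output
q_loss = torch.tensor(0.1)                                      # stand-in for the quantization loss

loss, log = criterion(q_loss, x, x_rec)
loss.backward()                                                 # gradients flow back to x_rec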
optvq/losses/aeloss_disc.py
ADDED
@@ -0,0 +1,177 @@
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------
# Modified from [thuanz123/enhancing-transformers](https://github.com/thuanz123/enhancing-transformers)
# Copyright (c) 2022 Thuan H. Nguyen. All Rights Reserved.
# ------------------------------------------------------------------------------
# Modified from [CompVis/taming-transformers](https://github.com/CompVis/taming-transformers)
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
# ------------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F
import lpips

from optvq.models.discriminator import NLayerDiscriminator, weights_init

class DummyLoss(nn.Module):
    def __init__(self):
        super().__init__()


def hinge_d_loss(logits_real, logits_fake):
    loss_real = torch.mean(F.relu(1. - logits_real))
    loss_fake = torch.mean(F.relu(1. + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def vanilla_d_loss(logits_real, logits_fake):
    d_loss = 0.5 * (
        torch.mean(torch.nn.functional.softplus(-logits_real)) +
        torch.mean(torch.nn.functional.softplus(logits_fake)))
    return d_loss


class AELossWithDisc(nn.Module):
    def __init__(self,
                 disc_start,
                 pixelloss_weight=1.0,
                 disc_in_channels=3,
                 disc_num_layers=3,
                 use_actnorm=False,
                 disc_ndf=64,
                 disc_conditional=False,
                 disc_loss="hinge",
                 loss_l1_weight: float = 1.0,
                 loss_l2_weight: float = 1.0,
                 loss_p_weight: float = 1.0,
                 loss_q_weight: float = 1.0,
                 loss_g_weight: float = 1.0,
                 loss_d_weight: float = 1.0
                 ):
        super(AELossWithDisc, self).__init__()
        assert disc_loss in ["hinge", "vanilla"]

        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = lpips.LPIPS(net="vgg", verbose=False).eval()

        self.loss_l1_weight = loss_l1_weight
        self.loss_l2_weight = loss_l2_weight
        self.loss_p_weight = loss_p_weight
        self.loss_q_weight = loss_q_weight
        self.loss_g_weight = loss_g_weight
        self.loss_d_weight = loss_d_weight

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm,
                                                 ndf=disc_ndf
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")

        self.disc_conditional = disc_conditional

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        g_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        g_weight = torch.clamp(g_weight, 0.0, 1e4).detach()
        g_weight = g_weight * self.loss_g_weight

        # detect NaN
        if torch.isnan(g_weight).any():
            g_weight = torch.tensor(0.0, device=g_weight.device)
        return g_weight

    @torch.autocast(device_type="cuda", enabled=False)
    def forward(self, codebook_loss, inputs, reconstructions, mode, last_layer=None, cond=None, global_step=0):
        x = inputs.contiguous().float()
        x_rec = reconstructions.contiguous().float()

        # compute q loss
        loss_q = codebook_loss.mean()

        # compute l1 loss
        loss_l1 = (x_rec - x).abs().mean() if self.loss_l1_weight > 0.0 else torch.tensor(0.0, device=x.device)

        # compute l2 loss
        loss_l2 = (x_rec - x).pow(2).mean() if self.loss_l2_weight > 0.0 else torch.tensor(0.0, device=x.device)

        # compute perceptual loss
        loss_p = self.perceptual_loss(x, x_rec).mean() if self.loss_p_weight > 0.0 else torch.tensor(0.0, device=x.device)

        # integrate reconstruction loss
        loss_rec = loss_l1 * self.loss_l1_weight + \
                   loss_l2 * self.loss_l2_weight + \
                   loss_p * self.loss_p_weight

        # setup the factor_disc
        if global_step < self.discriminator_iter_start:
            factor_disc = 0.0
        else:
            factor_disc = 1.0

        # now the GAN part
        if mode == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(x_rec)
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((x_rec, cond), dim=1))

            # compute g loss
            loss_g = - logits_fake.mean()

            try:
                loss_g_weight = self.calculate_adaptive_weight(loss_rec, loss_g, last_layer=last_layer)
            except RuntimeError:
                # assert not self.training
                loss_g_weight = torch.tensor(0.0)

            loss = loss_g * loss_g_weight * factor_disc + \
                   loss_q * self.loss_q_weight + \
                   loss_rec

            log = {"total_loss": loss.item(),
                   "loss_q": loss_q.item(),
                   "loss_rec": loss_rec.item(),
                   "loss_l1": loss_l1.item(),
                   "loss_l2": loss_l2.item(),
                   "loss_p": loss_p.item(),
                   "loss_g": loss_g.item(),
                   "loss_g_weight": loss_g_weight.item(),
                   "factor_disc": factor_disc,
                   }
            return loss, log

        if mode == 1:
            # second pass for discriminator update
            if cond is None:
                logits_real = self.discriminator(x.detach())
                logits_fake = self.discriminator(x_rec.detach())
            else:
                logits_real = self.discriminator(torch.cat((x.detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((x_rec.detach(), cond), dim=1))

            loss_d = self.disc_loss(logits_real, logits_fake).mean()
            loss = loss_d * self.loss_d_weight

            log = {"loss_d": loss_d.item(),
                   "logits_real": logits_real.mean().item(),
                   "logits_fake": logits_fake.mean().item()
                   }
            return loss, log
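AELossWithDisc is driven in two passes: mode=0 scores the reconstruction against the discriminator and adds the adaptively weighted generator term, then mode=1 trains the discriminator on detached tensors. The small check below (illustrative, not from the commit) shows the behaviour of the two discriminator losses on well-separated logits.

import torch
from optvq.losses.aeloss_disc import hinge_d_loss, vanilla_d_loss

logits_real = torch.tensor([2.0, 1.5, 3.0])      # confident "real" scores
logits_fake = torch.tensor([-2.0, -1.2, -3.0])   # confident "fake" scores
print(hinge_d_loss(logits_real, logits_fake))    # ~0: both margins already satisfied
print(vanilla_d_loss(logits_real, logits_fake))  # small but positive (softplus tails)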
optvq/models/__pycache__/discriminator.cpython-310.pyc
ADDED
Binary file (4.4 kB)
optvq/models/__pycache__/quantizer.cpython-310.pyc
ADDED
Binary file (8.15 kB)
optvq/models/__pycache__/vqgan.cpython-310.pyc
ADDED
Binary file (5.28 kB)
optvq/models/__pycache__/vqgan_hf.cpython-310.pyc
ADDED
Binary file (2.11 kB)
optvq/models/backbone/__pycache__/diffusion.cpython-310.pyc
ADDED
Binary file (15.5 kB)
optvq/models/backbone/diffusion.py
ADDED
@@ -0,0 +1,787 @@
# ------------------------------------------------------------------------------
# Copy from [CompVis/taming-transformers](https://github.com/CompVis/taming-transformers)
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
# ------------------------------------------------------------------------------

# pytorch_diffusion + derived encoder decoder
import math
import torch
import torch.nn as nn
import numpy as np


def get_timestep_embedding(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels):
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
                 dropout, temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)      # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)        # b,hw,hw    w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)    # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v, w_)       # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


class Model(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, use_timestep=True):
        super().__init__()
        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList([
                torch.nn.Linear(self.ch, self.temb_ch),
                torch.nn.Linear(self.temb_ch, self.temb_ch),
            ])

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out,
                                         temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    skip_in = ch * in_ch_mult[i_level]
                block.append(ResnetBlock(in_channels=block_in + skip_in, out_channels=block_out,
                                         temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x, t=None):
        # assert x.shape[2] == x.shape[3] == self.resolution

        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Encoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, **ignore_kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out,
                                         temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, 2 * z_channels if double_z else z_channels,
                                        kernel_size=3, stride=1, padding=1)

    @property
    def hidden_dim(self):
        return self.conv_out.out_channels

    def forward(self, x):
        # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)

        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, **ignorekwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out,
                                         temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    @property
    def hidden_dim(self):
        return self.conv_in.in_channels

    def forward(self, z):
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class VUNet(nn.Module):
    def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
                 in_channels, c_channels,
                 resolution, z_channels, use_timestep=False, **ignore_kwargs):
        super().__init__()
        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList([
                torch.nn.Linear(self.ch, self.temb_ch),
                torch.nn.Linear(self.temb_ch, self.temb_ch),
            ])

        # downsampling
        self.conv_in = torch.nn.Conv2d(c_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out,
                                         temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        self.z_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=1, stride=1, padding=0)
        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=2 * block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in,
                                       temb_channels=self.temb_ch, dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    skip_in = ch * in_ch_mult[i_level]
                block.append(ResnetBlock(in_channels=block_in + skip_in, out_channels=block_out,
                                         temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x, z):
        # assert x.shape[2] == x.shape[3] == self.resolution

        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        z = self.z_in(z)
        h = torch.cat((h, z), dim=1)
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class SimpleDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
                                    ResnetBlock(in_channels=in_channels,
                                                out_channels=2 * in_channels,
                                                temb_channels=0, dropout=0.0),
                                    ResnetBlock(in_channels=2 * in_channels,
                                                out_channels=4 * in_channels,
                                                temb_channels=0, dropout=0.0),
                                    ResnetBlock(in_channels=4 * in_channels,
                                                out_channels=2 * in_channels,
                                                temb_channels=0, dropout=0.0),
                                    nn.Conv2d(2 * in_channels, in_channels, 1),
                                    Upsample(in_channels, with_conv=True)])
        # end
        self.norm_out = Normalize(in_channels)
        self.conv_out = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        for i, layer in enumerate(self.model):
            if i in [1, 2, 3]:
                x = layer(x, None)
            else:
                x = layer(x)

        h = self.norm_out(x)
        h = nonlinearity(h)
        x = self.conv_out(h)
        return x


class UpsampleDecoder(nn.Module):
    def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
                 ch_mult=(2, 2), dropout=0.0):
        super().__init__()
        # upsampling
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = in_channels
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            res_block = []
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                res_block.append(ResnetBlock(in_channels=block_in, out_channels=block_out,
                                             temb_channels=self.temb_ch, dropout=dropout))
                block_in = block_out
            self.res_blocks.append(nn.ModuleList(res_block))
|
764 |
+
if i_level != self.num_resolutions - 1:
|
765 |
+
self.upsample_blocks.append(Upsample(block_in, True))
|
766 |
+
curr_res = curr_res * 2
|
767 |
+
|
768 |
+
# end
|
769 |
+
self.norm_out = Normalize(block_in)
|
770 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
771 |
+
out_channels,
|
772 |
+
kernel_size=3,
|
773 |
+
stride=1,
|
774 |
+
padding=1)
|
775 |
+
|
776 |
+
def forward(self, x):
|
777 |
+
# upsampling
|
778 |
+
h = x
|
779 |
+
for k, i_level in enumerate(range(self.num_resolutions)):
|
780 |
+
for i_block in range(self.num_res_blocks + 1):
|
781 |
+
h = self.res_blocks[i_level][i_block](h, None)
|
782 |
+
if i_level != self.num_resolutions - 1:
|
783 |
+
h = self.upsample_blocks[k](h)
|
784 |
+
h = self.norm_out(h)
|
785 |
+
h = nonlinearity(h)
|
786 |
+
h = self.conv_out(h)
|
787 |
+
return h
|
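For reference, a minimal usage sketch of the UpsampleDecoder added above (this snippet is not part of the committed file; the channel counts and resolution are illustrative assumptions, not values used by the Space):

    import torch
    from optvq.models.backbone.diffusion import UpsampleDecoder

    # ch_mult=(2, 2) gives two resolution levels, i.e. one 2x upsampling stage
    dec = UpsampleDecoder(in_channels=64, out_channels=3, ch=32,
                          num_res_blocks=1, resolution=16, ch_mult=(2, 2))
    latent = torch.randn(1, 64, 8, 8)   # dummy (B, C, H, W) latent
    img = dec(latent)                   # -> torch.Size([1, 3, 16, 16])
    print(img.shape)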
optvq/models/backbone/simple_cnn.py
ADDED
@@ -0,0 +1,90 @@
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------

import torch.nn as nn

class PlainCNNEncoder(nn.Module):
    def __init__(self, in_dim: int = 3):
        super(PlainCNNEncoder, self).__init__()

        self.in_dim = in_dim

        self.in_fc = nn.Conv2d(in_channels=in_dim, out_channels=16,
                               kernel_size=3, stride=1, padding=1, bias=True)
        self.act0 = nn.ReLU(inplace=True)

        self.down1 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(in_channels=16, out_channels=16,
                               kernel_size=3, stride=1, padding=1, bias=True)
        self.act1 = nn.ReLU(inplace=True)

        self.down2 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32,
                               kernel_size=3, stride=1, padding=1, bias=True)
        self.act2 = nn.ReLU(inplace=True)

        self.out_fc = nn.Conv2d(in_channels=32, out_channels=32,
                                kernel_size=3, stride=1, padding=1, bias=True)

    @property
    def hidden_dim(self):
        return 32

    def forward(self, x):
        x = self.in_fc(x)
        x = self.act0(x)

        x = self.down1(x)
        x = self.conv1(x)
        x = self.act1(x)

        x = self.down2(x)
        x = self.conv2(x)
        x = self.act2(x)

        x = self.out_fc(x)
        return x

class PlainCNNDecoder(nn.Module):
    def __init__(self, out_dim: int = 3):
        super(PlainCNNDecoder, self).__init__()
        self.out_dim = out_dim

        self.in_fc = nn.Conv2d(in_channels=32, out_channels=32,
                               kernel_size=3, stride=1, padding=1, bias=True)

        self.act1 = nn.ReLU(inplace=True)
        self.up1 = nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(in_channels=16, out_channels=16,
                               kernel_size=3, stride=1, padding=1, bias=True)

        self.act2 = nn.ReLU(inplace=True)
        self.up2 = nn.ConvTranspose2d(in_channels=16, out_channels=16, kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=16,
                               kernel_size=3, stride=1, padding=1, bias=True)

        self.act3 = nn.ReLU(inplace=True)
        self.out_fc = nn.Conv2d(in_channels=16, out_channels=out_dim,
                                kernel_size=3, stride=1, padding=1, bias=True)

    @property
    def hidden_dim(self):
        return 32

    def forward(self, x):
        x = self.in_fc(x)

        x = self.act1(x)
        x = self.up1(x)
        x = self.conv1(x)

        x = self.act2(x)
        x = self.up2(x)
        x = self.conv2(x)

        x = self.act3(x)
        x = self.out_fc(x)
        return x
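A quick shape check for the plain CNN backbone above: the encoder downsamples by 4x with two stride-2 convolutions and the decoder undoes it with two transposed convolutions. This standalone sketch is illustrative and not part of the committed file:

    import torch
    from optvq.models.backbone.simple_cnn import PlainCNNEncoder, PlainCNNDecoder

    enc, dec = PlainCNNEncoder(in_dim=3), PlainCNNDecoder(out_dim=3)
    x = torch.randn(1, 3, 32, 32)
    h = enc(x)        # -> torch.Size([1, 32, 8, 8])
    x_rec = dec(h)    # -> torch.Size([1, 3, 32, 32])
    print(h.shape, x_rec.shape)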
optvq/models/discriminator.py
ADDED
@@ -0,0 +1,155 @@
# ------------------------------------------------------------------------------
# Copy from [CompVis/taming-transformers](https://github.com/CompVis/taming-transformers)
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
# ------------------------------------------------------------------------------

import functools
import torch.nn as nn
import torch

class ActNorm(nn.Module):
    def __init__(self, num_features, logdet=False, affine=True,
                 allow_reverse_init=False):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        with torch.no_grad():
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = (
                flatten.mean(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )
            std = (
                flatten.std(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        if reverse:
            return self.reverse(input)
        if len(input.shape) == 2:
            input = input[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height * width * torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            else:
                self.initialize(output)
                self.initialized.fill_(1)

        if len(output.shape) == 2:
            output = output[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
    --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """
    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if not use_actnorm:
            norm_layer = nn.BatchNorm2d
        else:
            norm_layer = ActNorm
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)
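The discriminator is a patch classifier: it emits one real/fake logit per receptive-field patch rather than a single scalar. A minimal sketch with an assumed 256x256 input (not part of the committed file):

    import torch
    from optvq.models.discriminator import NLayerDiscriminator, weights_init

    disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).apply(weights_init)
    logits = disc(torch.randn(1, 3, 256, 256))
    print(logits.shape)   # torch.Size([1, 1, 30, 30]): one logit per image patch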
optvq/models/quantizer.py
ADDED
@@ -0,0 +1,260 @@
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------

import math
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

import torch.distributed as dist

import optvq.utils.logger as L

class VectorQuantizer(nn.Module):
    def __init__(self, n_e: int = 1024, e_dim: int = 128,
                 beta: float = 1.0, use_norm: bool = False,
                 use_proj: bool = True, fix_codes: bool = False,
                 loss_q_type: str = "ce",
                 num_head: int = 1,
                 start_quantize_steps: int = None):
        super(VectorQuantizer, self).__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.loss_q_type = loss_q_type
        self.num_head = num_head
        self.start_quantize_steps = start_quantize_steps
        self.code_dim = self.e_dim // self.num_head

        self.norm = lambda x: F.normalize(x, p=2.0, dim=-1, eps=1e-6) if use_norm else x
        assert not use_norm, "use_norm=True is no longer supported, because the norm operation lacks theoretical analysis and may cause unpredictable instability."
        self.use_proj = use_proj

        self.embedding = nn.Embedding(num_embeddings=n_e, embedding_dim=self.code_dim)
        if use_proj:
            self.proj = nn.Linear(self.code_dim, self.code_dim)
            torch.nn.init.normal_(self.proj.weight, std=self.code_dim ** -0.5)
        if fix_codes:
            self.embedding.weight.requires_grad = False

    def reshape_input(self, x: Tensor):
        """
        (B, C, H, W) / (B, T, C) -> (B, T, C)
        """
        if x.ndim == 4:
            _, C, H, W = x.size()
            x = x.permute(0, 2, 3, 1).contiguous().view(-1, H * W, C)
            return x, {"size": (H, W)}
        elif x.ndim == 3:
            return x, None
        else:
            raise ValueError("Invalid input shape!")

    def recover_output(self, x: Tensor, info):
        if info is not None:
            H, W = info["size"]
            if x.ndim == 3:  # features (B, T, C) -> (B, C, H, W)
                C = x.size(2)
                return x.view(-1, H, W, C).permute(0, 3, 1, 2).contiguous()
            elif x.ndim == 2:  # indices (B, T) -> (B, H, W)
                return x.view(-1, H, W)
            else:
                raise ValueError("Invalid input shape!")
        else:  # features (B, T, C) or indices (B, T)
            return x

    def get_codebook(self, return_numpy: bool = True):
        embed = self.proj(self.embedding.weight) if self.use_proj else self.embedding.weight
        if return_numpy:
            return embed.data.cpu().numpy()
        else:
            return embed.data

    def quantize_input(self, query, reference):
        # compute the distance matrix
        query2ref = torch.cdist(query, reference, p=2.0)  # (B1, B2)

        # find the nearest embedding
        indices = torch.argmin(query2ref, dim=-1)  # (B1,)
        nearest_ref = reference[indices]  # (B1, C)

        return indices, nearest_ref, query2ref

    def compute_codebook_loss(self, query, indices, nearest_ref, beta: float, query2ref):
        # compute the loss
        if self.loss_q_type == "l2":
            loss = torch.mean((query - nearest_ref.detach()).pow(2)) + \
                   torch.mean((nearest_ref - query.detach()).pow(2)) * beta
        elif self.loss_q_type == "l1":
            loss = torch.mean((query - nearest_ref.detach()).abs()) + \
                   torch.mean((nearest_ref - query.detach()).abs()) * beta
        elif self.loss_q_type == "ce":
            loss = F.cross_entropy(- query2ref, indices)

        return loss

    def compute_quantized_output(self, x, x_q):
        if self.start_quantize_steps is not None:
            if self.training and L.log.total_steps < self.start_quantize_steps:
                L.log.add_scalar("params/quantize_ratio", 0.0)
                return x
            else:
                L.log.add_scalar("params/quantize_ratio", 1.0)
                return x + (x_q - x).detach()
        else:
            L.log.add_scalar("params/quantize_ratio", 1.0)
            return x + (x_q - x).detach()

    @torch.autocast(device_type="cuda", enabled=False)
    def forward(self, x: Tensor):
        """
        Quantize the input tensor x with the embedding table.

        Args:
            x (Tensor): input tensor with shape (B, C, H, W) or (B, T, C)
        Returns:
            (tuple) containing: (x_q, loss, indices)
        """
        x = x.float()
        x, info = self.reshape_input(x)
        B, T, C = x.size()
        x = x.view(-1, C)  # (B * T, C)
        embed = self.proj(self.embedding.weight) if self.use_proj else self.embedding.weight

        # split the x if multi-head is used
        if self.num_head > 1:
            x = x.view(-1, self.code_dim)  # (B * T * nH, dC)

        # compute the distance between x and each embedding
        x, embed = self.norm(x), self.norm(embed)

        # compute losses
        indices, x_q, query2ref = self.quantize_input(x, embed)
        loss = self.compute_codebook_loss(
            query=x, indices=indices, nearest_ref=x_q,
            beta=self.beta, query2ref=query2ref
        )

        # compute statistics
        if self.training and L.GET_STATS:
            with torch.no_grad():
                num_unique = torch.unique(indices).size(0)
                x_norm_mean = torch.mean(x.norm(dim=-1))
                embed_norm_mean = torch.mean(embed.norm(dim=-1))
                diff_norm_mean = torch.mean((x_q - x).norm(dim=-1))
                x2e_mean = query2ref.mean()
                L.log.add_scalar("params/num_unique", num_unique)
                L.log.add_scalar("params/x_norm", x_norm_mean.item())
                L.log.add_scalar("params/embed_norm", embed_norm_mean.item())
                L.log.add_scalar("params/diff_norm", diff_norm_mean.item())
                L.log.add_scalar("params/x2e_mean", x2e_mean.item())

        # compute the final x_q
        x_q = self.compute_quantized_output(x, x_q).view(B, T, C)
        indices = indices.view(B, T, self.num_head)

        # for output
        x_q = self.recover_output(x_q, info)
        indices = self.recover_output(indices, info)

        return x_q, loss, indices

def sinkhorn(cost: Tensor, n_iters: int = 3, epsilon: float = 1, is_distributed: bool = False):
    """
    Sinkhorn algorithm.
    Args:
        cost (Tensor): shape with (B, K)
    """
    Q = torch.exp(- cost * epsilon).t()  # (K, B)
    if is_distributed:
        B = Q.size(1) * dist.get_world_size()
    else:
        B = Q.size(1)
    K = Q.size(0)

    # make the matrix sums to 1
    sum_Q = torch.sum(Q)
    if is_distributed:
        dist.all_reduce(sum_Q)
    Q /= (sum_Q + 1e-8)

    for _ in range(n_iters):
        # normalize each row: total weight per prototype must be 1/K
        sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
        if is_distributed:
            dist.all_reduce(sum_of_rows)
        Q /= (sum_of_rows + 1e-8)
        Q /= K

        # normalize each column: total weight per sample must be 1/B
        Q /= (torch.sum(Q, dim=0, keepdim=True) + 1e-8)
        Q /= B

    Q *= B  # the columns must sum to 1 so that Q is an assignment
    return Q.t()  # (B, K)

class VectorQuantizerSinkhorn(VectorQuantizer):
    def __init__(self, epsilon: float = 10.0, n_iters: int = 5,
                 normalize_mode: str = "all", use_prob: bool = True,
                 *args, **kwargs):
        super(VectorQuantizerSinkhorn, self).__init__(*args, **kwargs)
        self.epsilon = epsilon
        self.n_iters = n_iters
        self.normalize_mode = normalize_mode
        self.use_prob = use_prob

    def normalize(self, A, dim, mode="all"):
        if mode == "all":
            A = (A - A.mean()) / (A.std() + 1e-6)
            A = A - A.min()
        elif mode == "dim":
            A = A / math.sqrt(dim)
        elif mode == "null":
            pass
        return A

    def quantize_input(self, query, reference):
        # compute the distance matrix
        query2ref = torch.cdist(query, reference, p=2.0)  # (B1, B2)

        # compute the assignment matrix
        with torch.no_grad():
            is_distributed = dist.is_initialized() and dist.get_world_size() > 1
            normalized_cost = self.normalize(query2ref, dim=reference.size(1), mode=self.normalize_mode)
            Q = sinkhorn(normalized_cost, n_iters=self.n_iters, epsilon=self.epsilon, is_distributed=is_distributed)

            if self.use_prob:
                # avoid the zero value problem
                max_q_id = torch.argmax(Q, dim=-1)
                Q[torch.arange(Q.size(0)), max_q_id] += 1e-8
                indices = torch.multinomial(Q, num_samples=1).squeeze()
            else:
                indices = torch.argmax(Q, dim=-1)
            nearest_ref = reference[indices]

        if self.training and L.GET_STATS:
            if L.log.total_steps % 1000 == 0:
                L.log.add_histogram("params/normalized_cost", normalized_cost)

        return indices, nearest_ref, query2ref

class Identity(VectorQuantizer):
    @torch.autocast(device_type="cuda", enabled=False)
    def forward(self, x: Tensor):
        x = x.float()
        loss_q = torch.tensor(0.0, device=x.device, dtype=x.dtype)

        # compute statistics
        if self.training and L.GET_STATS:
            with torch.no_grad():
                x_flatten, _ = self.reshape_input(x)
                x_norm_mean = torch.mean(x_flatten.norm(dim=-1))
                L.log.add_scalar("params/x_norm", x_norm_mean.item())

        return x, loss_q, None
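The sinkhorn helper is the core of OptVQ's optimal-transport assignment: instead of greedily picking the nearest code, it rebalances the exp(-cost) matrix so that every query keeps unit mass while the codebook entries are used evenly. A self-contained sketch with made-up sizes (not part of the committed file):

    import torch
    from optvq.models.quantizer import sinkhorn

    cost = torch.cdist(torch.randn(8, 4), torch.randn(16, 4))   # 8 queries vs. 16 codes
    Q = sinkhorn(cost, n_iters=5, epsilon=10.0)
    print(Q.shape)          # torch.Size([8, 16])
    print(Q.sum(dim=-1))    # each query row sums to ~1
    print(Q.sum(dim=0))     # column sums are pushed toward B/K, i.e. balanced code usage
    indices = torch.argmax(Q, dim=-1)   # hard assignment, as used when use_prob=False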
optvq/models/vqgan.py
ADDED
@@ -0,0 +1,168 @@
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------
# Modified from [CompVis/taming-transformers](https://github.com/CompVis/taming-transformers)
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
# ------------------------------------------------------------------------------

import torch
import torch.nn.functional as F
import torch.nn as nn

import optvq.utils.logger as L

class Identity(nn.Module):
    def forward(self, x):
        return x

class VQModel(nn.Module):
    def __init__(self,
                 encoder: nn.Module,
                 decoder: nn.Module,
                 loss: nn.Module,
                 quantize: nn.Module,
                 ckpt_path: str = None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 use_connector: bool = True,
                 ):
        super(VQModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.loss = loss
        self.quantize = quantize
        self.use_connector = use_connector

        encoder_dim = self.encoder.hidden_dim
        decoder_dim = self.decoder.hidden_dim
        embed_dim = self.quantize.e_dim

        if not use_connector:
            self.quant_conv = Identity()
            self.post_quant_conv = Identity()
            assert encoder_dim == embed_dim, f"{encoder_dim} != {embed_dim}"
            assert decoder_dim == embed_dim, f"{decoder_dim} != {embed_dim}"
        else:
            self.quant_conv = torch.nn.Conv2d(encoder_dim, embed_dim, 1)
            self.post_quant_conv = torch.nn.Conv2d(embed_dim, decoder_dim, 1)

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.image_key = image_key
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location="cpu")["state_dict"]
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, indices = self.quantize(h)
        return quant, emb_loss, indices

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, x, mode: int = 0, global_step: int = None):
        """
        Args:
            x (torch.Tensor): input tensor
            mode (int): 0 for autoencoder, 1 for discriminator
            global_step (int): global step for adaptive discriminator weight
        """
        global_step = global_step if global_step is not None else L.log.total_steps
        quant, qloss, indices = self.encode(x)
        xrec = self.decode(quant)
        if mode == 0:
            # compute the autoencoder loss
            loss, log_dict = self.loss(qloss, x, xrec, mode, last_layer=self.get_last_layer(), global_step=global_step)
        elif mode == 1:
            # compute the discriminator loss
            loss, log_dict = self.loss(qloss, x, xrec, mode, last_layer=self.get_last_layer(), global_step=global_step)
        elif mode == 2:
            # compute the hidden embedding
            h = self.encoder(x)
            h = self.quant_conv(h)
            return h
        return loss, log_dict, indices

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
        return x.float()

    def get_last_layer(self):
        if hasattr(self.decoder, "conv_out"):
            return self.decoder.conv_out.weight
        elif hasattr(self.decoder, "out_fc"):
            return self.decoder.out_fc.weight
        elif hasattr(self.decoder, "inv_conv"):
            return self.decoder.inv_conv.weight
        else:
            raise NotImplementedError("Cannot find last layer in decoder")

    def log_images(self, batch, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        return log

    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x

    # The functions below are deprecated

    def validation_step(self, batch, batch_idx):
        x = self.get_input(batch, self.image_key)
        xrec, qloss = self(x)
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
                                        last_layer=self.get_last_layer(), split="val")

        discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
                                            last_layer=self.get_last_layer(), split="val")
        rec_loss = log_dict_ae["val/rec_loss"]
        self.log("val/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        self.log("val/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict
optvq/models/vqgan_hf.py
ADDED
@@ -0,0 +1,69 @@
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------

# Convert a Pytorch model to a Hugging Face model

import torch.nn as nn

from huggingface_hub import PyTorchModelHubMixin

from optvq.models.backbone.diffusion import Encoder, Decoder
from optvq.models.quantizer import VectorQuantizer, VectorQuantizerSinkhorn
from optvq.losses.aeloss_disc import AELossWithDisc
from optvq.models.vqgan import VQModel

class VQModelHF(nn.Module, PyTorchModelHubMixin):
    def __init__(self,
                 encoder: dict = {},
                 decoder: dict = {},
                 loss: dict = {},
                 quantize: dict = {},
                 quantize_type: str = "optvq",
                 ckpt_path: str = None,
                 ignore_keys=[],
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 use_connector: bool = True,
                 ):
        super(VQModelHF, self).__init__()
        encoder = Encoder(**encoder)
        decoder = Decoder(**decoder)
        quantizer = self.setup_quantizer(quantize, quantize_type)
        loss = AELossWithDisc(**loss)

        self.model = VQModel(
            encoder=encoder,
            decoder=decoder,
            loss=loss,
            quantize=quantizer,
            ckpt_path=ckpt_path,
            ignore_keys=ignore_keys,
            image_key=image_key,
            colorize_nlabels=colorize_nlabels,
            monitor=monitor,
            use_connector=use_connector,
        )

    def setup_quantizer(self, quantizer_config, quantize_type):
        if quantize_type == "optvq":
            quantizer = VectorQuantizerSinkhorn(**quantizer_config)
        elif quantize_type == "basevq":
            quantizer = VectorQuantizer(**quantizer_config)
        else:
            raise ValueError(f"Unknown quantizer type: {quantize_type}")
        return quantizer

    def encode(self, x):
        return self.model.encode(x)

    def decode(self, x):
        return self.model.decode(x)

    def forward(self, x):
        quant, *_ = self.encode(x)
        rec = self.decode(quant)
        return quant, rec
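Because VQModelHF mixes in PyTorchModelHubMixin, pretrained weights can be pulled straight from the Hub. A minimal sketch; the repository id below is a placeholder, not a confirmed checkpoint name:

    import torch
    from optvq.models.vqgan_hf import VQModelHF

    model = VQModelHF.from_pretrained("<user>/<optvq-checkpoint>")   # placeholder repo id
    model.eval()
    with torch.no_grad():
        quant, rec = model(torch.randn(1, 3, 256, 256))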
optvq/trainer/__pycache__/arguments.cpython-310.pyc
ADDED
Binary file (1.58 kB)
optvq/trainer/__pycache__/pipeline.cpython-310.pyc
ADDED
Binary file (9.02 kB)
optvq/trainer/arguments.py
ADDED
@@ -0,0 +1,53 @@
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------

import argparse

def get_parser():
    parser = argparse.ArgumentParser()
    # arguments with high priority
    parser.add_argument("--seed", type=int, default=42,
                        help="The random seed.")
    parser.add_argument("--gpu", type=int, nargs="+", default=None,
                        help="The GPU ids to use.")
    parser.add_argument("--is_distributed", action="store_true", default=False)
    parser.add_argument("--config", type=str, default=None,
                        help="The path to the configuration file.")
    parser.add_argument("--resume", type=str, default=None,
                        help="The path to the checkpoint to resume.")
    parser.add_argument("--device_rank", type=int, default=0)

    # arguments for the training
    parser.add_argument("--log_dir", type=str, default=None,
                        help="The path to the log directory.")
    parser.add_argument("--mode", type=str, default="train",
                        help="options: train, test")
    parser.add_argument("--use_initiate", type=str, default=None,
                        help="Options: random, kmeans")
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--enterpoint", type=str, default=None)
    parser.add_argument("--code_path", type=str, default=None)
    parser.add_argument("--embed_path", type=str, default=None)

    # arguments for the data
    parser.add_argument("--use_train_subset", type=float, default=None,
                        help="The size of the training subset. None means using the full training set.")
    parser.add_argument("--use_train_repeat", type=int, default=None,
                        help="The number of times to repeat the training set.")
    parser.add_argument("--use_val_subset", type=float, default=None,
                        help="The size of the validation subset. None means using the full validation set.")
    parser.add_argument("--batch_size", type=int, default=None)
    parser.add_argument("--gradient_accumulate", type=int, default=None)

    # arguments for the optimizer
    parser.add_argument("--lr", type=float, default=None)
    parser.add_argument("--mul_lr", type=float, default=None)

    # arguments for the model
    parser.add_argument("--num_codes", type=int, default=None,
                        help="The number of codes.")

    return parser
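A small sketch of how the parser is meant to be consumed; the config path and values below are made up for illustration:

    from optvq.trainer.arguments import get_parser

    parser = get_parser()
    opt = parser.parse_args(["--config", "configs/example.yaml", "--gpu", "0", "--batch_size", "8"])
    print(opt.config, opt.gpu, opt.batch_size)   # configs/example.yaml [0] 8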
optvq/trainer/pipeline.py
ADDED
@@ -0,0 +1,362 @@
1 |
+
# ------------------------------------------------------------------------------
|
2 |
+
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
|
3 |
+
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
|
4 |
+
# Licensed under the MIT License [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------------
|
6 |
+
|
7 |
+
from typing import Callable
|
8 |
+
import argparse
|
9 |
+
import os
|
10 |
+
from omegaconf import OmegaConf
|
11 |
+
from functools import partial
|
12 |
+
from torchinfo import summary
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
from torch.utils.data.dataloader import DataLoader
|
17 |
+
|
18 |
+
from optvq.utils.init import initiate_from_config_recursively
|
19 |
+
from optvq.data.dataloader import maybe_get_subset
|
20 |
+
import optvq.utils.logger as L
|
21 |
+
|
22 |
+
def setup_config(opt: argparse.Namespace):
|
23 |
+
L.log.info("\n\n### Setting up the configurations. ###")
|
24 |
+
|
25 |
+
# load the config files
|
26 |
+
config = OmegaConf.load(opt.config)
|
27 |
+
|
28 |
+
# overwrite the certain arguments according to the config.args mapping
|
29 |
+
for key, value in config.args_map.items():
|
30 |
+
if hasattr(opt, key) and getattr(opt, key) is not None:
|
31 |
+
msg = f"config.{value} = opt.{key}"
|
32 |
+
L.log.info(f"Overwrite the config: {msg}")
|
33 |
+
exec(msg)
|
34 |
+
|
35 |
+
return config
|
36 |
+
|
37 |
+
def setup_dataloader(data, batch_size, is_distributed: bool = True, is_train: bool = True, num_workers: int = 8):
|
38 |
+
if is_train:
|
39 |
+
if is_distributed:
|
40 |
+
# setup the sampler
|
41 |
+
sampler = torch.utils.data.distributed.DistributedSampler(data, shuffle=True, drop_last=True)
|
42 |
+
# setup the dataloader
|
43 |
+
loader = DataLoader(
|
44 |
+
dataset=data, batch_size=batch_size, num_workers=num_workers,
|
45 |
+
drop_last=True, sampler=sampler, persistent_workers=True, pin_memory=True
|
46 |
+
)
|
47 |
+
else:
|
48 |
+
# setup the dataloader
|
49 |
+
loader = DataLoader(
|
50 |
+
dataset=data, batch_size=batch_size, num_workers=num_workers,
|
51 |
+
drop_last=True, shuffle=True, persistent_workers=True, pin_memory=True
|
52 |
+
)
|
53 |
+
else:
|
54 |
+
if is_distributed:
|
55 |
+
# setup the sampler
|
56 |
+
sampler = torch.utils.data.distributed.DistributedSampler(data, shuffle=False, drop_last=False)
|
57 |
+
# setup the dataloader
|
58 |
+
loader = DataLoader(
|
59 |
+
dataset=data, batch_size=batch_size, num_workers=num_workers,
|
60 |
+
drop_last=False, sampler=sampler, persistent_workers=True, pin_memory=True
|
61 |
+
)
|
62 |
+
else:
|
63 |
+
# setup the dataloader
|
64 |
+
loader = DataLoader(
|
65 |
+
dataset=data, batch_size=batch_size, num_workers=num_workers,
|
66 |
+
drop_last=False, shuffle=False, persistent_workers=True, pin_memory=True
|
67 |
+
)
|
68 |
+
|
69 |
+
return loader
|
70 |
+
|
71 |
+
def setup_dataset(config: OmegaConf):
|
72 |
+
L.log.info("\n\n### Setting up the datasets. ###")
|
73 |
+
|
74 |
+
# setup the training dataset
|
75 |
+
train_data = initiate_from_config_recursively(config.data.train)
|
76 |
+
if config.data.use_train_subset is not None:
|
77 |
+
train_data = maybe_get_subset(train_data, subset_size=config.data.use_train_subset, num_data_repeat=config.data.use_train_repeat)
|
78 |
+
L.log.info(f"Training dataset size: {len(train_data)}")
|
79 |
+
|
80 |
+
# setup the validation dataset
|
81 |
+
val_data = initiate_from_config_recursively(config.data.val)
|
82 |
+
if config.data.use_val_subset is not None:
|
83 |
+
val_data = maybe_get_subset(val_data, subset_size=config.data.use_val_subset)
|
84 |
+
L.log.info(f"Validation dataset size: {len(val_data)}")
|
85 |
+
|
86 |
+
return train_data, val_data
|
87 |
+
|
88 |
+
def setup_model(config: OmegaConf, device):
|
89 |
+
L.log.info("\n\n### Setting up the models. ###")
|
90 |
+
|
91 |
+
# setup the model
|
92 |
+
model = initiate_from_config_recursively(config.model.autoencoder)
|
93 |
+
if config.is_distributed:
|
94 |
+
# apply syncBN
|
95 |
+
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
96 |
+
# model to devices
|
97 |
+
model = model.to(device)
|
98 |
+
find_unused_parameters = True
|
99 |
+
model = torch.nn.parallel.DistributedDataParallel(
|
100 |
+
module=model, device_ids=[config.gpu],
|
101 |
+
find_unused_parameters=find_unused_parameters
|
102 |
+
)
|
103 |
+
model_ori = model.module
|
104 |
+
else:
|
105 |
+
model = model.to(device)
|
106 |
+
model_ori = model
|
107 |
+
|
108 |
+
input_size = config.data.train.params.transform.params.resize
|
109 |
+
in_channels = getattr(model_ori.encoder, "in_dim", 3)
|
110 |
+
sout = summary(model_ori, (1, in_channels, input_size, input_size), device="cuda", verbose=0)
|
111 |
+
L.log.info(sout)
|
112 |
+
|
113 |
+
# count the total number of parameters
|
114 |
+
for name, module in model_ori.named_children():
|
115 |
+
num_params = sum(p.numel() for p in module.parameters())
|
116 |
+
num_trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
|
117 |
+
L.log.info(f"Module: {name}, Total params: {num_params}, Trainable params: {num_trainable}")
|
118 |
+
|
119 |
+
return model
|
120 |
+
|
121 |
+
### factory functions
|
122 |
+
|
123 |
+
def get_setup_optimizers(config):
|
124 |
+
name = config.train.pipeline
|
125 |
+
func_name = "setup_optimizers_" + name
|
126 |
+
return globals()[func_name]
|
127 |
+
|
128 |
+
def get_pipeline(config):
|
129 |
+
name = config.train.pipeline
|
130 |
+
func_name = "pipeline_" + name
|
131 |
+
return globals()[func_name]
|
132 |
+
|
133 |
+
def _forward_backward(
|
134 |
+
config,
|
135 |
+
x: torch.Tensor,
|
136 |
+
forward: Callable,
|
137 |
+
model: nn.Module,
|
138 |
+
optimizer: torch.optim.Optimizer,
|
139 |
+
scheduler: torch.optim.lr_scheduler._LRScheduler,
|
140 |
+
scaler: torch.cuda.amp.GradScaler,
|
141 |
+
):
|
142 |
+
with torch.autocast(device_type="cuda", dtype=torch.bfloat16,
|
143 |
+
enabled=config.use_amp):
|
144 |
+
# forward pass
|
145 |
+
loss, *output = forward(x)
|
146 |
+
loss_acc = loss / config.data.gradient_accumulate
|
147 |
+
scaler.scale(loss_acc).backward()
|
148 |
+
# gradient accumulate
|
149 |
+
if L.log.total_steps % config.data.gradient_accumulate == 0:
|
150 |
+
scaler.unscale_(optimizer)
|
151 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
152 |
+
scaler.step(optimizer)
|
153 |
+
optimizer.zero_grad()
|
154 |
+
scaler.update()
|
155 |
+
|
156 |
+
if scheduler is not None:
|
157 |
+
scheduler.step()
|
158 |
+
return loss, output
|
159 |
+
|
160 |
+
### autoencoder version
|
161 |
+
def _find_weight_decay_id(modules: list, params_ids: list,
|
162 |
+
include_class: tuple = (nn.Linear, nn.Conv2d,
|
163 |
+
nn.ConvTranspose2d,
|
164 |
+
nn.MultiheadAttention),
|
165 |
+
include_name: list = ["weight"]):
|
166 |
+
for mod in modules:
|
167 |
+
for sub_mod in mod.modules():
|
168 |
+
if isinstance(sub_mod, include_class):
|
169 |
+
for name, param in sub_mod.named_parameters():
|
170 |
+
if any([k in name for k in include_name]):
|
171 |
+
params_ids.append(id(param))
|
172 |
+
params_ids = list(set(params_ids))
|
173 |
+
return params_ids
|
174 |
+
|
175 |
+
def set_weight_decay(modules: list):
|
176 |
+
weight_decay_ids = _find_weight_decay_id(modules, [])
|
177 |
+
wd_params, wd_names, no_wd_params, no_wd_names = [], [], [], []
|
178 |
+
for mod in modules:
|
179 |
+
for name, param in mod.named_parameters():
|
180 |
+
if id(param) in weight_decay_ids:
|
181 |
+
wd_params.append(param)
|
182 |
+
wd_names.append(name)
|
183 |
+
else:
|
184 |
+
no_wd_params.append(param)
|
185 |
+
no_wd_names.append(name)
|
186 |
+
return wd_params, wd_names, no_wd_params, no_wd_names
|
187 |
+
|
188 |
+
def setup_optimizers_ae(config: OmegaConf, model: nn.Module, total_steps: int):
|
189 |
+
L.log.info("\n\n### Setting up the optimizers and schedulers. ###")
|
190 |
+
|
191 |
+
# compute the total batch size and the learning rate
|
192 |
+
total_batch_size = config.data.batch_size * config.world_size * config.data.gradient_accumulate
|
193 |
+
total_learning_rate = config.train.learning_rate * total_batch_size
|
194 |
+
multipled_learning_rate = total_learning_rate * config.train.mul_learning_rate
|
195 |
+
L.log.info(f"Total batch size: {total_batch_size} = {config.data.batch_size} * {config.world_size} * {config.data.gradient_accumulate}")
|
196 |
+
L.log.info(f"Total learning rate: {total_learning_rate} = {config.train.learning_rate} * {total_batch_size}")
|
197 |
+
L.log.info(f"Multipled learning rate: {multipled_learning_rate} = {total_learning_rate} * {config.train.mul_learning_rate}")
|
198 |
+
|
199 |
+
# setup the optimizers
|
200 |
+
param_group = []
|
201 |
+
## base learning rate
|
202 |
+
wd_params, wd_names, no_wd_params, no_wd_names = set_weight_decay([model.encoder, model.decoder, model.quant_conv, model.post_quant_conv])
|
203 |
+
param_group.append({
|
204 |
+
"params": wd_params, "lr": total_learning_rate, "eps": 1e-7,
|
205 |
+
"weight_decay": config.train.weight_decay, "beta": (0.9, 0.999),
|
206 |
+
})
|
207 |
+
param_group.append({
|
208 |
+
"params": no_wd_params, "lr": total_learning_rate, "eps": 1e-7,
|
209 |
+
"weight_decay": 0.0, "beta": (0.9, 0.999),
|
210 |
+
})
|
211 |
+
## multipled learning rate
|
212 |
+
wd_params, wd_names, no_wd_params, no_wd_names = set_weight_decay([model.quantize])
|
213 |
+
param_group.append({
|
214 |
+
"params": wd_params, "lr": multipled_learning_rate, "eps": 1e-7,
|
215 |
+
"weight_decay": config.train.weight_decay, "beta": (0.9, 0.999),
|
216 |
+
})
|
217 |
+
param_group.append({
|
218 |
+
"params": no_wd_params, "lr": multipled_learning_rate, "eps": 1e-7,
|
219 |
+
"weight_decay": 0.0, "beta": (0.9, 0.999),
|
220 |
+
})
|
221 |
+
|
222 |
+
optimizer_ae = torch.optim.AdamW(param_group)
|
223 |
+
optimizer_dict = {"optimizer_ae": optimizer_ae}
|
224 |
+
|
225 |
+
# setup the schedulers
|
226 |
+
scheduler_ae = torch.optim.lr_scheduler.OneCycleLR(
|
227 |
+
optimizer=optimizer_ae, max_lr=[total_learning_rate, total_learning_rate, multipled_learning_rate, multipled_learning_rate],
|
228 |
+
total_steps=total_steps, pct_start=0.01, anneal_strategy="cos"
|
229 |
+
)
|
230 |
+
scheduler_dict = {"scheduler_ae": scheduler_ae}
|
231 |
+
|
232 |
+
# setup the scalers
|
233 |
+
scaler_dict = {"scaler_ae": torch.GradScaler(enabled=config.use_amp)}
|
234 |
+
L.log.info(f"Enable AMP: {config.use_amp}")
|
235 |
+
return optimizer_dict, scheduler_dict, scaler_dict
|
236 |
+
|
237 |
+
def pipeline_ae(
|
238 |
+
config,
|
239 |
+
x: torch.Tensor,
|
240 |
+
model: nn.Module,
|
241 |
+
optimizers: dict,
|
242 |
+
schedulers: dict,
|
243 |
+
scalers: dict,
|
244 |
+
):
|
245 |
+
assert "optimizer_ae" in optimizers
|
246 |
+
assert "scheduler_ae" in schedulers
|
247 |
+
assert "scaler_ae" in scalers
|
248 |
+
|
249 |
+
optimizer = optimizers["optimizer_ae"]
|
250 |
+
scheduler = schedulers["scheduler_ae"]
|
251 |
+
scaler = scalers["scaler_ae"]
|
252 |
+
|
253 |
+
forward = partial(model, mode=0)
|
254 |
+
_, (loss_ae_dict, indices) = _forward_backward(config, x, forward, model, optimizer, scheduler, scaler)
|
255 |
+
|
256 |
+
log_per_step = loss_ae_dict
|
257 |
+
log_per_epoch = {"indices": indices}
|
258 |
+
return log_per_step, log_per_epoch
|
259 |
+
|
260 |
+
### autoencoder + disc version
|
261 |
+
|
262 |
+
def setup_optimizers_ae_disc(config: OmegaConf, model: nn.Module, total_steps: int):
|
263 |
+
L.log.info("\n\n### Setting up the optimizers and schedulers. ###")
|
264 |
+
|
265 |
+
# compute the total batch size and the learning rate
|
266 |
+
total_batch_size = config.data.batch_size * config.world_size * config.data.gradient_accumulate
|
267 |
+
total_learning_rate = config.train.learning_rate * total_batch_size
|
268 |
+
multipled_learning_rate = total_learning_rate * config.train.mul_learning_rate
|
269 |
+
L.log.info(f"Total batch size: {total_batch_size} = {config.data.batch_size} * {config.world_size} * {config.data.gradient_accumulate}")
|
270 |
+
L.log.info(f"Total learning rate: {total_learning_rate} = {config.train.learning_rate} * {total_batch_size}")
|
271 |
+
L.log.info(f"Multipled learning rate: {multipled_learning_rate} = {total_learning_rate} * {config.train.mul_learning_rate}")

    # setup the optimizers
    param_group = []
    ## base learning rate
    wd_params, wd_names, no_wd_params, no_wd_names = set_weight_decay([model.encoder, model.decoder, model.quant_conv, model.post_quant_conv])
    param_group.append({
        "params": wd_params, "lr": total_learning_rate, "eps": 1e-7,
        "weight_decay": config.train.weight_decay, "betas": (0.9, 0.999),
    })
    param_group.append({
        "params": no_wd_params, "lr": total_learning_rate, "eps": 1e-7,
        "weight_decay": 0.0, "betas": (0.9, 0.999),
    })
    ## multiplied learning rate (applied to the quantizer)
    wd_params, wd_names, no_wd_params, no_wd_names = set_weight_decay([model.quantize])
    param_group.append({
        "params": wd_params, "lr": multipled_learning_rate, "eps": 1e-7,
        "weight_decay": config.train.weight_decay, "betas": (0.9, 0.999),
    })
    param_group.append({
        "params": no_wd_params, "lr": multipled_learning_rate, "eps": 1e-7,
        "weight_decay": 0.0, "betas": (0.9, 0.999),
    })
    optimizer_ae = torch.optim.AdamW(param_group)

    param_group = []
    wd_params, wd_names, no_wd_params, no_wd_names = set_weight_decay([model.loss.discriminator])
    param_group.append({
        "params": wd_params, "lr": total_learning_rate, "eps": 1e-7,
        "weight_decay": config.train.weight_decay, "betas": (0.9, 0.999),
    })
    param_group.append({
        "params": no_wd_params, "lr": total_learning_rate, "eps": 1e-7,
        "weight_decay": 0.0, "betas": (0.9, 0.999),
    })
    optimizer_disc = torch.optim.AdamW(param_group)
    optimizer_dict = {"optimizer_ae": optimizer_ae, "optimizer_disc": optimizer_disc}

    # setup the schedulers
    scheduler_ae = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer_ae, max_lr=[total_learning_rate, total_learning_rate, multipled_learning_rate, multipled_learning_rate],
        total_steps=total_steps, pct_start=0.01, anneal_strategy="cos"
    )
    scheduler_disc = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer_disc, max_lr=[total_learning_rate, total_learning_rate],
        total_steps=total_steps, pct_start=0.01, anneal_strategy="cos"
    )
    scheduler_dict = {"scheduler_ae": scheduler_ae, "scheduler_disc": scheduler_disc}

    # setup the scalers
    scaler_dict = {"scaler_ae": torch.GradScaler(enabled=config.use_amp),
                   "scaler_disc": torch.GradScaler(enabled=config.use_amp)}
    L.log.info(f"Enable AMP: {config.use_amp}")
    return optimizer_dict, scheduler_dict, scaler_dict

def pipeline_ae_disc(
    config,
    x: torch.Tensor,
    model: nn.Module,
    optimizers: dict,
    schedulers: dict,
    scalers: dict,
):
    # autoencoder step
    assert "optimizer_ae" in optimizers
    assert "scheduler_ae" in schedulers
    assert "scaler_ae" in scalers

    optimizer = optimizers["optimizer_ae"]
    scheduler = schedulers["scheduler_ae"]
    scaler = scalers["scaler_ae"]

    forward = partial(model, mode=0)
    _, (loss_ae_dict, indices) = _forward_backward(config, x, forward, model, optimizer, scheduler, scaler)

    log_per_step = loss_ae_dict
    log_per_epoch = {"indices": indices}

    # discriminator step
    assert "optimizer_disc" in optimizers
    assert "scheduler_disc" in schedulers
    assert "scaler_disc" in scalers

    optimizer = optimizers["optimizer_disc"]
    scheduler = schedulers["scheduler_disc"]
    scaler = scalers["scaler_disc"]

    forward = partial(model, mode=1)
    _, (loss_disc_dict, _) = _forward_backward(config, x, forward, model, optimizer, scheduler, scaler)
    log_per_step.update(loss_disc_dict)
    return log_per_step, log_per_epoch
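
Usage sketch (not part of the commit): pipeline_ae_disc runs one autoencoder update (mode=0) followed by one discriminator update (mode=1) per batch. A training loop might drive it as below; `train_loader`, `device`, and the three dicts returned by the optimizer/scheduler/scaler setup above are assumed names, not code from this repository.

for x, _ in train_loader:
    x = x.to(device)
    log_per_step, log_per_epoch = pipeline_ae_disc(
        config, x, model,
        optimizers=optimizer_dict,
        schedulers=scheduler_dict,
        scalers=scaler_dict,
    )
    # log_per_step: AE and discriminator losses for this batch
    # log_per_epoch: codebook indices selected by the quantizer
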
optvq/utils/__pycache__/func.cpython-310.pyc
ADDED
Binary file (1.08 kB).
optvq/utils/__pycache__/init.cpython-310.pyc
ADDED
Binary file (1.25 kB).
optvq/utils/__pycache__/logger.cpython-310.pyc
ADDED
Binary file (14 kB).
optvq/utils/__pycache__/metrics.cpython-310.pyc
ADDED
Binary file (2.08 kB).
optvq/utils/func.py
ADDED
@@ -0,0 +1,27 @@
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------

from typing import Tuple, Union, Iterable
from omegaconf import OmegaConf
import os

import torch
import torch.distributed as dist

def dist_all_gather(x):
    tensor_list = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(tensor_list, x)
    x = torch.cat(tensor_list, dim=0)
    return x

def any_2tuple(data: Union[int, Tuple[int]]) -> Tuple[int]:
    if isinstance(data, int):
        return (data, data)
    elif isinstance(data, Iterable):
        assert len(data) == 2, "target size must be tuple of (w, h)"
        return tuple(data)
    else:
        raise ValueError("target size must be int or tuple of (w, h)")
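
Usage sketch (not part of the commit): any_2tuple normalizes a size argument to a (w, h) pair.

from optvq.utils.func import any_2tuple

print(any_2tuple(256))         # (256, 256)
print(any_2tuple((224, 192)))  # (224, 192)
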
optvq/utils/init.py
ADDED
@@ -0,0 +1,36 @@
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------

import importlib
import random
import numpy as np
from typing import Mapping, Optional

import torch

def initiate_from_config(config: Mapping):
    assert "target" in config, f"Expected key `target` to initialize!"
    module, cls = config["target"].rsplit(".", 1)
    meta_class = getattr(importlib.import_module(module, package=None), cls)
    return meta_class(**config.get("params", dict()))

def initiate_from_config_recursively(config: Mapping):
    assert "target" in config, f"Expected key `target` to initialize!"
    update_config = {"target": config["target"], "params": {}}
    for k, v in config["params"].items():
        if isinstance(v, Mapping) and "target" in v:
            sub_instance = initiate_from_config_recursively(v)
            update_config["params"][k] = sub_instance
        else:
            update_config["params"][k] = v
    return initiate_from_config(update_config)

def seed_everything(seed: Optional[int] = None):
    seed = int(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
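
Usage sketch (not part of the commit): initiate_from_config builds an object from a `target`/`params` mapping, and initiate_from_config_recursively resolves nested `target` blocks first. The torch.nn.Linear target below is only an illustration.

from optvq.utils.init import initiate_from_config

layer = initiate_from_config({
    "target": "torch.nn.Linear",
    "params": {"in_features": 8, "out_features": 4},
})
print(layer)  # Linear(in_features=8, out_features=4, bias=True)
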
optvq/utils/logger.py
ADDED
@@ -0,0 +1,386 @@
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------

import time
import datetime
from typing import List
import functools
import os
from PIL import Image
from termcolor import colored
import sys
import logging
from omegaconf import OmegaConf
import json

try:
    from torch.utils.tensorboard import SummaryWriter
    from torch import Tensor
    import torch
except:
    raise ImportError("Please install torch to use this module!")

"""
NOTE: The `log` instance is a global variable, which should be imported by other modules as:
`import optvq.utils.logger as logger`
rather than
`from optvq.utils.logger import log`.
"""

def setup_printer(file_log_dir: str, use_console: bool = True):
    printer = logging.getLogger("LOG")
    printer.setLevel(logging.DEBUG)
    printer.propagate = False

    # create formatter
    fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
    color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
        colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'

    # create the console handler
    if use_console:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setLevel(logging.DEBUG)
        console_handler.setFormatter(
            logging.Formatter(fmt=color_fmt, datefmt="%Y-%m-%d %H:%M:%S")
        )
        printer.addHandler(console_handler)

    # create the file handler
    file_handler = logging.FileHandler(os.path.join(file_log_dir, "record.txt"), mode="a")
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(
        logging.Formatter(fmt=fmt, datefmt="%Y-%m-%d %H:%M:%S")
    )
    printer.addHandler(file_handler)

    return printer

@functools.lru_cache()
def config_loggers(log_dir: str, local_rank: int = 0, master_rank: int = 0):
    global log

    if local_rank == master_rank:
        log = LogManager(log_dir=log_dir, main_logger=True)
    else:
        log = LogManager(log_dir=log_dir, main_logger=False)

class ProgressWithIndices:
    def __init__(self, total: int, sep_char: str = "| ",
                 num_per_row: int = 4):
        self.total = total
        self.sep_char = sep_char
        self.num_per_row = num_per_row

        self.count = 0
        self.start_time = time.time()
        self.past_time = None
        self.current_time = None
        self.eta = None
        self.speed = None
        self.used_time = 0

    def update(self):
        self.count += 1
        if self.count <= self.total:
            self.past_time = self.current_time
            self.current_time = time.time()
            # compute eta
            if self.past_time is not None:
                self.eta = (self.total - self.count) * (self.current_time - self.past_time)
                self.eta = str(datetime.timedelta(seconds=int(self.eta)))
                self.speed = 1 / (self.current_time - self.past_time + 1e-8)
            # compute used time
            self.used_time = self.current_time - self.start_time
            self.used_time = str(datetime.timedelta(seconds=int(self.used_time)))
        else:
            self.eta = 0
            self.speed = 0
            self.past_time = None
            self.current_time = None

    def print(self, prefix: str = "", content: str = ""):
        global log
        prefix_str = f"{prefix}\t" + f"[{self.count}/{self.total} {self.used_time}/Eta:{self.eta}], Speed:{self.speed}iters/s\n"
        content_list = content.split(self.sep_char)
        content_list = [content.strip() for content in content_list]
        content_list = [
            "\t\t" + self.sep_char.join(content_list[i:i + self.num_per_row])
            for i in range(0, len(content_list), self.num_per_row)
        ]
        content = prefix_str + "\n".join(content_list)
        log.info(content)

class LogManager:
    """
    This class encapsulates the tensorboard writer, the statistic meters, the console printer, and the progress counters.

    Args:
        log_dir (str): the parent directory to save all the logs
        init_meters (List[str]): the initial meters to be shown
        show_avg (bool): whether to show the average value of the meters
    """
    def __init__(self, log_dir: str, init_meters: List[str] = [],
                 show_avg: bool = True, main_logger: bool = False):

        # initiate all the directories
        self.show_avg = show_avg
        self.log_dir = log_dir
        self.main_logger = main_logger
        self.setup_dirs()

        # initiate the statistic meters
        self.meters = {meter: AverageMeter() for meter in init_meters}

        # initiate the progress counters
        self.total_steps = 0
        self.total_epochs = 0

        if self.main_logger:
            # initiate the tensorboard writer
            self.board = SummaryWriter(log_dir=self.tb_log_dir)

            # initiate the console printer
            self.printer = setup_printer(self.file_log_dir, use_console=True)

    def state_dict(self):
        return {
            "total_steps": self.total_steps,
            "total_epochs": self.total_epochs,
            "meters": {
                meter_name: meter.state_dict() for meter_name, meter in self.meters.items()
            }
        }

    def load_state_dict(self, state_dict: dict):
        self.total_steps = state_dict["total_steps"]
        self.total_epochs = state_dict["total_epochs"]
        for meter_name, meter_state_dict in state_dict["meters"].items():
            if meter_name not in self.meters:
                self.meters[meter_name] = AverageMeter()
            self.meters[meter_name].load_state_dict(meter_state_dict)

    ### About directories
    def setup_dirs(self):
        """
        The structure of the log directory:
        - log_dir: [tb_log, txt_log, img_log, model_log]
        """
        self.tb_log_dir = os.path.join(self.log_dir, "tb_log")
        # NOTE: For now, we save the txt records in the parent directory
        # self.file_log_dir = os.path.join(self.log_dir, "txt_log")
        self.file_log_dir = self.log_dir
        self.img_log_dir = os.path.join(self.log_dir, "img_log")

        self.config_path = os.path.join(self.log_dir, "config.yaml")
        self.checkpoint_path = os.path.join(self.log_dir, "checkpoint.pth")
        self.backup_checkpoint_path = os.path.join(self.log_dir, "checkpoint.pth")
        self.save_logger_path = os.path.join(self.log_dir, "logger.json")

        if self.main_logger:
            os.makedirs(self.tb_log_dir, exist_ok=True)
            os.makedirs(self.file_log_dir, exist_ok=True)
            os.makedirs(self.img_log_dir, exist_ok=True)

    ### About printer

    def info(self, msg, *args, **kwargs):
        if self.main_logger:
            self.printer.info(msg, *args, **kwargs)

    def show(self, include_key: str = ""):
        if isinstance(include_key, str):
            include_key = [include_key]
        if self.show_avg:
            return "| ".join([f"{meter_name}: {meter.val:.4f}/{meter.avg:.4f}" for meter_name, meter in self.meters.items() if any([k in meter_name for k in include_key])])
        else:
            return "| ".join([f"{meter_name}: {meter.val:.4f}" for meter_name, meter in self.meters.items() if any([k in meter_name for k in include_key])])

    ### About counter

    def update_steps(self):
        self.total_steps += 1
        return self.total_steps

    def update_epochs(self):
        self.total_epochs += 1
        return self.total_epochs

    ### About tensorboard
    def add_histogram(self, tag: str, values: Tensor, global_step: int = None):
        if self.main_logger:
            global_step = self.total_steps if global_step is None else global_step
            self.board.add_histogram(tag, values, global_step)

    def add_scalar(self, tag: str, scalar_value: float, global_step: int = None):
        if isinstance(scalar_value, Tensor):
            scalar_value = scalar_value.item()
        if tag in self.meters:
            cur_step = self.meters[tag].update(scalar_value)
            cur_step = cur_step if global_step is None else global_step
            if self.main_logger:
                self.board.add_scalar(tag, scalar_value, cur_step)
        else:
            self.meters[tag] = AverageMeter()
            cur_step = self.meters[tag].update(scalar_value)
            cur_step = cur_step if global_step is None else global_step
            if self.main_logger:
                print(f"Create new meter: {tag}!")
                self.board.add_scalar(tag, scalar_value, cur_step)

    def add_scalar_dict(self, scalar_dict: dict, global_step: int = None):
        for tag, scalar_value in scalar_dict.items():
            self.add_scalar(tag, scalar_value, global_step)

    def add_images(self, tag: str, images: Tensor, global_step: int = None):
        if self.main_logger:
            global_step = self.total_steps if global_step is None else global_step
            self.board.add_images(tag, images, global_step, dataformats="NCHW")

    ### About saving and resuming
    def save_configs(self, config):
        if self.main_logger:
            # save config as yaml file
            OmegaConf.save(config, self.config_path)
            self.info(f"Save config to {self.config_path}.")

            # save logger
            state_dict = self.state_dict()
            with open(self.save_logger_path, "w") as f:
                json.dump(state_dict, f)

    def load_configs(self):
        # load config
        assert os.path.exists(self.config_path), f"Config {self.config_path} does not exist!"
        config = OmegaConf.load(self.config_path)

        # load logger
        assert os.path.exists(self.save_logger_path), f"Logger {self.save_logger_path} does not exist!"
        state_dict = json.load(open(self.save_logger_path, "r"))
        self.load_state_dict(state_dict)

        return config

    def save_checkpoint(self, model, optimizers, schedulers, scalers, suffix: str = ""):
        """
        checkpoint_dict: model, optimizer, scheduler, scalers
        """
        if self.main_logger:

            # save checkpoint_dict
            checkpoint_dict = {
                "model": model.state_dict(),
                "epoch": self.total_epochs,
                "step": self.total_steps
            }
            checkpoint_dict.update({k: v.state_dict() for k, v in optimizers.items()})
            checkpoint_dict.update({k: v.state_dict() for k, v in schedulers.items() if v is not None})
            checkpoint_dict.update({k: v.state_dict() for k, v in scalers.items()})

            checkpoint_path = self.checkpoint_path + suffix
            torch.save(checkpoint_dict, checkpoint_path)
            if os.path.exists(self.backup_checkpoint_path):
                os.remove(self.backup_checkpoint_path)
            self.backup_checkpoint_path = checkpoint_path + f".epoch{self.total_epochs}"
            torch.save(checkpoint_dict, self.backup_checkpoint_path)

            self.info(f"### Epoch: {self.total_epochs}| Steps: {self.total_steps}| Save checkpoint to {checkpoint_path}.")

    def load_checkpoint(self, device, model, optimizers, schedulers, scalers, resume: str = None):
        resume_path = self.checkpoint_path if resume is None else resume
        assert os.path.exists(resume_path), f"Resume {resume_path} does not exist!"

        # load checkpoint_dict
        checkpoint_dict = torch.load(resume_path, map_location=device)
        model.load_state_dict(checkpoint_dict["model"])
        self.total_epochs = checkpoint_dict["epoch"]
        self.total_steps = checkpoint_dict["step"]
        for k, v in optimizers.items():
            v.load_state_dict(checkpoint_dict[k])
        for k, v in schedulers.items():
            v.load_state_dict(checkpoint_dict[k])
        for k, v in scalers.items():
            v.load_state_dict(checkpoint_dict[k])

        self.info(f"### Epoch: {self.total_epochs}| Steps: {self.total_steps}| Resume checkpoint from {resume_path}.")

        return self.total_epochs

class EmptyManager:
    def __init__(self):
        for func_name in LogManager.__dict__.keys():
            if not func_name.startswith("_"):
                # bind func_name as a default argument so each stub reports its own name
                # (a plain closure would late-bind and always print the last name)
                setattr(self, func_name, lambda *args, _name=func_name, **kwargs: print(f"Empty Manager! {_name} is not available!"))

class AverageMeter:
    def __init__(self):
        self.reset()

    def state_dict(self):
        return {
            "val": self.val,
            "avg": self.avg,
            "sum": self.sum,
            "count": self.count,
        }

    def load_state_dict(self, state_dict: dict):
        self.val = state_dict["val"]
        self.avg = state_dict["avg"]
        self.sum = state_dict["sum"]
        self.count = state_dict["count"]

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        return 0

    def update(self, val: float, n: int = 1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        return self.count

    def __str__(self):
        return f"{self.avg:.4f}"

def save_image(x: Tensor, save_path: str, scale_to_256: bool = True):
    """
    Args:
        x (tensor): default data range is [0, 1]
    """
    if scale_to_256:
        x = x.mul(255).clamp(0, 255)
    x = x.permute(1, 2, 0).detach().cpu().numpy().astype("uint8")
    img = Image.fromarray(x)
    img.save(save_path)

def save_images(images_list, ids_list, meta_path):
    for i, (image, id) in enumerate(zip(images_list, ids_list)):
        save_path = os.path.join(meta_path, f"{id}.png")
        save_image(image, save_path)

def save_images_multithread(images_list, ids_list, meta_path):
    n_workers = 32
    from concurrent.futures import ThreadPoolExecutor
    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        for i in range(0, len(images_list), n_workers):
            cur_images = images_list[i:(i + n_workers)]
            cur_ids = ids_list[i:(i + n_workers)]
            executor.submit(save_images, cur_images, cur_ids, meta_path)

def add_prefix(log_dict: dict, prefix: str):
    return {
        f"{prefix}/{key}": val for key, val in log_dict.items()
    }

##################### GLOBAL VARIABLES #####################
log = EmptyManager()
GET_STATS: bool = (os.environ.get("ENABLE_STATS", "1") == "1")
###########################################################
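
Usage sketch (not part of the commit), following the NOTE above: import the module rather than the `log` name, configure it once on the master rank, then log through the module-level instance. The "./logs" path is only an example.

import optvq.utils.logger as logger

logger.config_loggers(log_dir="./logs", local_rank=0, master_rank=0)
logger.log.info("training started")
logger.log.add_scalar_dict({"train/loss_rec": 0.123, "train/loss_disc": 0.456})
logger.log.update_steps()
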
optvq/utils/metrics.py
ADDED
@@ -0,0 +1,60 @@
# ------------------------------------------------------------------------------
# OptVQ: Preventing Local Pitfalls in Vector Quantization via Optimal Transport
# Copyright (c) 2024 Borui Zhang. All Rights Reserved.
# Licensed under the MIT License [see LICENSE for details]
# ------------------------------------------------------------------------------

import numpy as np

import torch

from pytorch_fid.inception import InceptionV3
from pytorch_fid.fid_score import calculate_frechet_distance

class FIDMetric:
    def __init__(self, device, dims=2048):
        self.device = device
        self.num_workers = 32

        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
        self.model = InceptionV3([block_idx]).to(device)
        self.model.eval()

        self.reset_metrics()

    def reset_metrics(self):
        self.x_pred = []
        self.x_rec_pred = []

    @torch.no_grad()
    def get_activates(self, x: torch.Tensor):
        pred = self.model(x)[0]
        # If model output is not scalar, apply global spatial average pooling.
        # This happens if you choose a dimensionality not equal 2048.
        if pred.size(2) != 1 or pred.size(3) != 1:
            pred = torch.nn.functional.adaptive_avg_pool2d(pred, output_size=(1, 1))
        return pred.squeeze().cpu().numpy()

    def update(self, x: torch.Tensor, x_rec: torch.Tensor):
        """
        Args:
            x (torch.Tensor): input tensor range from 0 to 1
            x_rec (torch.Tensor): reconstructed tensor range from 0 to 1
        """
        self.x_pred.append(self.get_activates(x))
        self.x_rec_pred.append(self.get_activates(x_rec))

    def result(self):
        assert len(self.x_pred) != 0, "No data to compute FID"
        x = np.concatenate(self.x_pred, axis=0)
        x_rec = np.concatenate(self.x_rec_pred, axis=0)

        x_mu = np.mean(x, axis=0)
        x_sigma = np.cov(x, rowvar=False)

        x_rec_mu = np.mean(x_rec, axis=0)
        x_rec_sigma = np.cov(x_rec, rowvar=False)

        fid_score = calculate_frechet_distance(x_mu, x_sigma, x_rec_mu, x_rec_sigma)
        self.reset_metrics()
        return fid_score
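
Usage sketch (not part of the commit): accumulate Inception activations batch by batch, then compute the score. The random tensors below are stand-ins; a meaningful FID needs real images and reconstructions in [0, 1] and far more samples.

import torch
from optvq.utils.metrics import FIDMetric

device = torch.device("cpu")
fid = FIDMetric(device=device)
for _ in range(2):
    x = torch.rand(4, 3, 64, 64, device=device)      # stand-in for input images
    x_rec = torch.rand(4, 3, 64, 64, device=device)  # stand-in for reconstructions
    fid.update(x, x_rec)
print(fid.result())
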
requirements.txt
ADDED
@@ -0,0 +1,25 @@
torch>=2.3.0
torchvision
torchaudio
numpy
matplotlib
scikit-learn
opencv-python
einops>=0.7.0
omegaconf>=2.3.0
lightning>=2.1.3
transformers>=4.36.2
pudb
wandb
tensorboard
termcolor
lpips==0.1.4
ftfy
imageio
imageio-ffmpeg
pyiqa
clean-fid
albumentations
pytorch-fid
torchinfo
gradio