# Source snapshot: commit 6c4dee3 (15,269 bytes).
import math
from typing import List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from torch import distributed as tdist
from torch import nn as nn
from torch.nn import functional as F
import dist
# Public API of this module: only the VectorQuantizer2 used by the VQVAE.
# The Phi* helper classes below are not exported by `from ... import *`.
__all__ = [
    "VectorQuantizer2",
]
class VectorQuantizer2(nn.Module):
    """Multi-scale residual vector quantizer used by the VQVAE.

    A latent ``f`` of shape (B, Cvae, H, W) is quantized from the smallest to
    the largest entry of ``v_patch_nums``. At each scale:

    1. The remaining residual is area-downsampled to the scale's size. The
       last scale is used at full resolution.
    2. Every position is replaced by its nearest codebook vector.
    3. The code map is bicubically upsampled back to (H, W).
    4. The result is passed through the ``quant_resi`` module selected for
       that scale, added to ``f_hat`` and subtracted from the residual.
    """

    # VQGAN originally use beta=1.0, never tried 0.25; SD seems using 0.25
    def __init__(
        self,
        vocab_size,
        Cvae,
        using_znorm,
        beta: float = 0.25,
        default_qresi_counts=0,
        v_patch_nums=None,
        quant_resi=0.5,
        share_quant_resi=4,  # share_quant_resi: args.qsr
    ):
        """Build the codebook, the per-scale refiners and the usage buffer.

        Args:
            vocab_size: number of codebook entries.
            Cvae: channel width of the latent and of each code vector.
            using_znorm: if True, match codes by cosine similarity
                (L2-normalised features and codes). Otherwise match by
                squared Euclidean distance.
            beta: weight of the commitment term in the VQ loss.
            default_qresi_counts: number of ``Phi`` modules when
                ``share_quant_resi == 0``. 0 means one per scale.
            v_patch_nums: token-map side lengths, small to large. Required in
                practice: ``len(v_patch_nums)`` is taken below.
            quant_resi: blend ratio of each ``Phi``. When
                ``abs(quant_resi) <= 1e-6``, ``nn.Identity`` is used instead.
            share_quant_resi: 0 = one ``Phi`` per scale (non-shared),
                1 = a single ``Phi`` for every scale, otherwise this many
                ``Phi`` modules shared across scales by nearest tick.
        """
        super().__init__()
        self.vocab_size: int = vocab_size
        self.Cvae: int = Cvae
        self.using_znorm: bool = using_znorm
        self.v_patch_nums: Tuple[int] = v_patch_nums
        self.quant_resi_ratio = quant_resi

        def make_phi():
            # A (near-)zero ratio makes Phi a no-op, so skip the conv entirely.
            return Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()

        if share_quant_resi == 0:  # non-shared: \phi_{1 to K} for K scales
            self.quant_resi = PhiNonShared(
                [make_phi() for _ in range(default_qresi_counts or len(self.v_patch_nums))]
            )
        elif share_quant_resi == 1:  # fully shared: only a single \phi for K scales
            self.quant_resi = PhiShared(make_phi())
        else:  # partially shared: \phi_{1 to share_quant_resi} for K scales
            self.quant_resi = PhiPartiallyShared(
                nn.ModuleList([make_phi() for _ in range(share_quant_resi)])
            )

        # Per-scale EMA of codebook hit counts, read by forward(ret_usages=True).
        self.register_buffer(
            "ema_vocab_hit_SV",
            torch.full((len(self.v_patch_nums), self.vocab_size), fill_value=0.0),
        )
        self.record_hit = 0  # number of EMA updates so far (one per scale per training call)
        self.beta: float = beta
        self.embedding = nn.Embedding(self.vocab_size, self.Cvae)

    def eini(self, eini):
        """Re-initialise the codebook.

        ``eini > 0`` selects truncated normal with std ``eini``.
        ``eini < 0`` selects uniform in ``+-|eini| / vocab_size``.
        ``eini == 0`` leaves the weights untouched.
        """
        if eini > 0:
            nn.init.trunc_normal_(self.embedding.weight.data, std=eini)
        elif eini < 0:
            self.embedding.weight.data.uniform_(
                -abs(eini) / self.vocab_size, abs(eini) / self.vocab_size
            )

    def extra_repr(self) -> str:
        return f"{self.v_patch_nums}, znorm={self.using_znorm}, beta={self.beta} | S={len(self.v_patch_nums)}, quant_resi={self.quant_resi_ratio}"

    # ----------------------------- private helpers -----------------------------
    @staticmethod
    def _scale_pos(si: int, SN: int) -> float:
        """Map scale index ``si`` (of ``SN`` scales) to a position in [0, 1].

        The position is used to index ``self.quant_resi``. With a single scale
        this returns 0.0 instead of dividing by zero.
        """
        return si / (SN - 1) if SN > 1 else 0.0

    @staticmethod
    def _rest_to_NC(f_rest: torch.Tensor, ph: int, pw: int, is_last: bool) -> torch.Tensor:
        """Flatten a (B, C, H, W) residual to (B*ph*pw, C) rows.

        The residual is area-downsampled to (ph, pw) first, unless this is the
        last (full-resolution) scale.
        """
        if not is_last:
            f_rest = F.interpolate(f_rest, size=(ph, pw), mode="area")
        return f_rest.permute(0, 2, 3, 1).reshape(-1, f_rest.shape[1])

    def _nearest_idx(self, z_NC: torch.Tensor) -> torch.Tensor:
        """Return, for each row of ``z_NC`` (N, Cvae), the index of its nearest code."""
        codebook = self.embedding.weight.data
        if self.using_znorm:
            # Cosine similarity: argmax of normalised dot products.
            z_NC = F.normalize(z_NC, dim=-1)
            return torch.argmax(z_NC @ F.normalize(codebook.T, dim=0), dim=1)
        # Squared L2 distance: ||z||^2 + ||e||^2 - 2 z.e (the cross term added in place).
        d_no_grad = torch.sum(z_NC.square(), dim=1, keepdim=True) + torch.sum(
            codebook.square(), dim=1, keepdim=False
        )
        d_no_grad.addmm_(z_NC, codebook.T, alpha=-2, beta=1)  # (B*h*w, vocab_size)
        return torch.argmin(d_no_grad, dim=1)

    def _idx_to_hBChw(self, idx_Bhw: torch.Tensor, H: int, W: int, is_last: bool) -> torch.Tensor:
        """Look up codes for an index map (B, h, w) and return them as (B, Cvae, H, W).

        Codes are bicubically upsampled unless this is the last scale.
        """
        h_BChw = self.embedding(idx_Bhw).permute(0, 3, 1, 2)
        if not is_last:
            h_BChw = F.interpolate(h_BChw, size=(H, W), mode="bicubic")
        return h_BChw.contiguous()

    def _update_hit_ema(self, si: int, hit_V: torch.Tensor, first_record: bool) -> None:
        """Fold this step's code hit counts for scale ``si`` into ``ema_vocab_hit_SV``."""
        ema = self.ema_vocab_hit_SV[si]
        if first_record:
            ema.copy_(hit_V)
        elif self.record_hit < 100:
            ema.mul_(0.9).add_(hit_V.mul(0.1))
        else:
            ema.mul_(0.99).add_(hit_V.mul(0.01))
        self.record_hit += 1

    # ===================== `forward` is only used in VAE training =====================
    def forward(
        self, f_BChw: torch.Tensor, ret_usages=False
    ) -> Tuple[torch.Tensor, Optional[List[float]], torch.Tensor]:
        """Quantize ``f_BChw`` over all scales and compute the VQ loss.

        Args:
            f_BChw: encoder features (B, Cvae, H, W). Cast to float32 if needed.
            ret_usages: also return per-scale codebook usage.

        Returns:
            f_hat: straight-through output. Its value is the quantized
                reconstruction, and its gradient flows to ``f_BChw``.
            usages: if ``ret_usages``, for each scale the percentage of codes
                whose EMA hit count is >= margin. Otherwise None.
            mean_vq_loss: ``beta * commitment + codebook`` MSE, averaged over
                scales.
        """
        if f_BChw.dtype != torch.float32:
            f_BChw = f_BChw.float()
        B, _, H, W = f_BChw.shape
        f_no_grad = f_BChw.detach()
        f_rest = f_no_grad.clone()
        f_hat = torch.zeros_like(f_rest)

        with torch.cuda.amp.autocast(enabled=False):
            mean_vq_loss: torch.Tensor = 0.0
            SN = len(self.v_patch_nums)
            sync_hits = self.training and dist.initialized()
            # Captured once: record_hit advances per scale, so testing it inside the
            # loop would only seed the EMA of scale 0 on the very first call.
            first_record = self.record_hit == 0
            for si, pn in enumerate(self.v_patch_nums):  # from small to large
                is_last = si == SN - 1
                # find the nearest embedding
                idx_N = self._nearest_idx(self._rest_to_NC(f_rest, pn, pn, is_last))

                hit_V = idx_N.bincount(minlength=self.vocab_size).float()
                # Overlap the cross-rank hit-count reduction with the loss computation.
                handler = tdist.all_reduce(hit_V, async_op=True) if sync_hits else None

                # calc loss
                h_BChw = self._idx_to_hBChw(idx_N.view(B, pn, pn), H, W, is_last)
                h_BChw = self.quant_resi[self._scale_pos(si, SN)](h_BChw)
                f_hat = f_hat + h_BChw
                f_rest -= h_BChw

                if self.training:
                    if handler is not None:
                        handler.wait()
                    self._update_hit_ema(si, hit_V, first_record)

                mean_vq_loss += F.mse_loss(f_hat.data, f_BChw).mul_(self.beta) + F.mse_loss(f_hat, f_no_grad)

            mean_vq_loss *= 1.0 / SN
            # Straight-through estimator: forward value of f_hat, gradient of f_BChw.
            f_hat = (f_hat.data - f_no_grad).add_(f_BChw)

        if ret_usages:
            # tdist.get_world_size() raises without a process group; treat that as one rank.
            world_size = (
                tdist.get_world_size()
                if tdist.is_available() and tdist.is_initialized()
                else 1
            )
            margin = world_size * (f_BChw.numel() / f_BChw.shape[1]) / self.vocab_size * 0.08
            # margin = pn*pn / 100
            usages = [
                (self.ema_vocab_hit_SV[si] >= margin).float().mean().item() * 100
                for si in range(len(self.v_patch_nums))
            ]
        else:
            usages = None
        return f_hat, usages, mean_vq_loss

    # ===================== `forward` is only used in VAE training =====================
    def embed_to_fhat(
        self, ms_h_BChw: List[torch.Tensor], all_to_max_scale=True, last_one=False
    ) -> Union[List[torch.Tensor], torch.Tensor]:
        """Accumulate per-scale code maps ``ms_h_BChw`` (small to large) into ``f_hat``.

        Args:
            ms_h_BChw: one (B, Cvae, pn, pn) map per scale. These are
                presumably embedded tokens; confirm against callers.
            all_to_max_scale: upsample every map to the largest scale before
                summing (the normal path). If False, ``f_hat`` itself is
                upsampled scale by scale (experimental).
            last_one: return only the final ``f_hat`` instead of one per scale.

        Returns:
            The final ``f_hat`` tensor if ``last_one``, else a list with the
            running ``f_hat`` after each scale.
        """
        ls_f_hat_BChw = []
        B = ms_h_BChw[0].shape[0]
        H = W = self.v_patch_nums[-1]
        SN = len(self.v_patch_nums)
        if all_to_max_scale:
            f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, H, W, dtype=torch.float32)
            for si, pn in enumerate(self.v_patch_nums):  # from small to large
                h_BChw = ms_h_BChw[si]
                if si < SN - 1:
                    h_BChw = F.interpolate(h_BChw, size=(H, W), mode="bicubic")
                h_BChw = self.quant_resi[self._scale_pos(si, SN)](h_BChw)
                f_hat.add_(h_BChw)
                if last_one:
                    ls_f_hat_BChw = f_hat
                else:
                    # f_hat keeps being mutated in place, so store a snapshot.
                    ls_f_hat_BChw.append(f_hat.clone())
        else:
            # WARNING: this is not the case in VQ-VAE training or inference (we'll interpolate every token map to the max H W, like above)
            # WARNING: this should only be used for experimental purpose
            f_hat = ms_h_BChw[0].new_zeros(
                B,
                self.Cvae,
                self.v_patch_nums[0],
                self.v_patch_nums[0],
                dtype=torch.float32,
            )
            for si, pn in enumerate(self.v_patch_nums):  # from small to large
                # interpolate returns a new tensor, so appending without clone is safe
                f_hat = F.interpolate(f_hat, size=(pn, pn), mode="bicubic")
                h_BChw = self.quant_resi[self._scale_pos(si, SN)](ms_h_BChw[si])
                f_hat.add_(h_BChw)
                if last_one:
                    ls_f_hat_BChw = f_hat
                else:
                    ls_f_hat_BChw.append(f_hat)
        return ls_f_hat_BChw

    def f_to_idxBl_or_fhat(
        self,
        f_BChw: torch.Tensor,
        to_fhat: bool,
        v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None,
        noise_std: Optional[float] = None,
    ) -> List[Union[torch.Tensor, torch.LongTensor]]:  # z_BChw is the feature from inp_img_no_grad
        """Quantize ``f_BChw`` scale by scale without computing any loss.

        Args:
            f_BChw: features (B, Cvae, H, W). The last patch size must equal (H, W).
            to_fhat: collect the running ``f_hat`` per scale instead of index maps.
            v_patch_nums: optional override of the scales. Each entry is an int
                ``pn`` or an ``(h, w)`` pair.
            noise_std: if given, mix Gaussian noise into the residual before
                matching, as ``sqrt(1 - s^2) * z + s * noise``.

        Returns:
            One entry per scale. Each entry is a clone of ``f_hat``
            (B, Cvae, H, W) if ``to_fhat``, else the token indices (B, ph*pw).
        """
        B, _, H, W = f_BChw.shape
        f_no_grad = f_BChw.detach()
        f_rest = f_no_grad.clone()
        f_hat = torch.zeros_like(f_rest)

        f_hat_or_idx_Bl: List[torch.Tensor] = []

        patch_hws = [
            (pn, pn) if isinstance(pn, int) else (pn[0], pn[1])
            for pn in (v_patch_nums or self.v_patch_nums)
        ]  # from small to large
        assert (
            patch_hws[-1][0] == H and patch_hws[-1][1] == W
        ), f"{patch_hws[-1]=} != ({H=}, {W=})"

        SN = len(patch_hws)
        for si, (ph, pw) in enumerate(patch_hws):  # from small to large
            is_last = si == SN - 1
            # find the nearest embedding
            z_NC = self._rest_to_NC(f_rest, ph, pw, is_last)
            if noise_std is not None:
                z_NC = math.sqrt(1 - noise_std ** 2) * z_NC + torch.randn_like(z_NC) * noise_std
            idx_N = self._nearest_idx(z_NC)

            h_BChw = self._idx_to_hBChw(idx_N.view(B, ph, pw), H, W, is_last)
            h_BChw = self.quant_resi[self._scale_pos(si, SN)](h_BChw)
            f_hat.add_(h_BChw)
            f_rest.sub_(h_BChw)
            f_hat_or_idx_Bl.append(
                f_hat.clone() if to_fhat else idx_N.reshape(B, ph * pw)
            )

        return f_hat_or_idx_Bl

    # ===================== idxBl_to_var_input: only used in VAR training, for getting teacher-forcing input =====================
    def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> Optional[torch.Tensor]:
        """Build teacher-forcing inputs from ground-truth per-scale token indices.

        Args:
            gt_ms_idx_Bl: per-scale index tensors. Entry ``si`` must embed to
                ``pn_si * pn_si`` tokens per sample (it is viewed as
                (B, C, pn, pn)).

        Returns:
            A float32 tensor (B, sum of pn^2 over scales 1..SN-1, Cvae). Entry
            ``si`` is the accumulated ``f_hat`` area-downsampled to scale
            ``si + 1``. Returns None with a single scale.
        """
        next_scales = []
        B = gt_ms_idx_Bl[0].shape[0]
        C = self.Cvae
        H = W = self.v_patch_nums[-1]
        SN = len(self.v_patch_nums)

        f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32)
        pn_next: int = self.v_patch_nums[0]
        for si in range(SN - 1):
            h_BChw = F.interpolate(
                self.embedding(gt_ms_idx_Bl[si])
                .transpose_(1, 2)
                .view(B, C, pn_next, pn_next),
                size=(H, W),
                mode="bicubic",
            )
            f_hat.add_(self.quant_resi[self._scale_pos(si, SN)](h_BChw))
            pn_next = self.v_patch_nums[si + 1]
            next_scales.append(
                F.interpolate(f_hat, size=(pn_next, pn_next), mode="area")
                .view(B, C, -1)
                .transpose(1, 2)
            )
        # cat BlCs to BLC, this should be float32
        return torch.cat(next_scales, dim=1) if len(next_scales) else None

    # ===================== get_next_autoregressive_input: only used in VAR inference, for getting next step's input =====================
    def get_next_autoregressive_input(
        self, si: int, SN: int, f_hat: torch.Tensor, h_BChw: torch.Tensor
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:  # only used in VAR inference
        """Add scale ``si``'s code map ``h_BChw`` to ``f_hat`` (in place).

        Returns:
            ``(f_hat, next_input)``. For a non-last scale, ``next_input`` is
            ``f_hat`` area-downsampled to the next scale's size. For the last
            scale, ``f_hat`` is returned twice.
        """
        HW = self.v_patch_nums[-1]
        if si != SN - 1:
            h = self.quant_resi[self._scale_pos(si, SN)](
                F.interpolate(h_BChw, size=(HW, HW), mode="bicubic")
            )  # conv after upsample
            f_hat.add_(h)
            return f_hat, F.interpolate(
                f_hat,
                size=(self.v_patch_nums[si + 1], self.v_patch_nums[si + 1]),
                mode="area",
            )
        h = self.quant_resi[self._scale_pos(si, SN)](h_BChw)
        f_hat.add_(h)
        return f_hat, f_hat
class Phi(nn.Conv2d):
    """Channel-preserving 3x3 convolution blended with its own input.

    ``forward(h) = (1 - r) * h + r * conv(h)``, where ``r = |quant_resi|``.
    """

    def __init__(self, embed_dim, quant_resi):
        kernel = 3
        super().__init__(
            in_channels=embed_dim,
            out_channels=embed_dim,
            kernel_size=kernel,
            stride=1,
            padding=kernel // 2,  # keeps H and W unchanged for this odd kernel
        )
        self.resi_ratio = abs(quant_resi)

    def forward(self, h_BChw):
        ratio = self.resi_ratio
        kept = h_BChw.mul(1 - ratio)
        return kept + super().forward(h_BChw).mul_(ratio)
class PhiShared(nn.Module):
    """Holds one refinement module and returns it for every index.

    Indexing ignores its argument, so every scale position maps to the same
    module.
    """

    def __init__(self, qresi: Phi):
        super().__init__()
        self.qresi: Phi = qresi  # registered as the single shared submodule

    def __getitem__(self, _) -> Phi:
        shared_module = self.qresi
        return shared_module
class PhiPartiallyShared(nn.Module):
    """K refinement modules shared across scales by nearest tick.

    Indexing with a position in [0, 1] returns the module whose tick in
    ``self.ticks`` is closest. Ties go to the lower index, as with
    ``np.argmin``.
    """

    def __init__(self, qresi_ls: nn.ModuleList):
        super().__init__()
        self.qresi_ls = qresi_ls
        count = len(qresi_ls)
        # K == 4 uses a 1/(3K) edge margin; every other K uses 1/(2K).
        edge = 1 / 3 / count if count == 4 else 1 / 2 / count
        self.ticks = np.linspace(edge, 1 - edge, count)

    def __getitem__(self, at_from_0_to_1: float) -> Phi:
        distances = np.abs(self.ticks - at_from_0_to_1)
        nearest = np.argmin(distances).item()
        return self.qresi_ls[nearest]

    def extra_repr(self) -> str:
        return f"ticks={self.ticks}"
class PhiNonShared(nn.ModuleList):
    """A ModuleList of refinement modules looked up by scale position.

    ``self[pos]`` with ``pos`` in [0, 1] returns the module whose tick in
    ``self.ticks`` is closest to ``pos``, not the element at integer index
    ``pos``.
    """

    def __init__(self, qresi: List):
        super().__init__(qresi)
        count = len(qresi)
        # K == 4 uses a 1/(3K) edge margin; every other K uses 1/(2K).
        edge = 1 / 3 / count if count == 4 else 1 / 2 / count
        self.ticks = np.linspace(edge, 1 - edge, count)

    def __getitem__(self, at_from_0_to_1: float) -> Phi:
        nearest = np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()
        return super().__getitem__(nearest)

    def extra_repr(self) -> str:
        return f"ticks={self.ticks}"