# PopYou / models/var.py
# Author: AmitIsraeli; last commit bbf15e6 ("change mps rng")
# (Hugging Face file-viewer residue: raw | history | blame | 18.2 kB)
import math
from functools import partial
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
import dist
from models.basic_var import AdaLNBeforeHead, AdaLNSelfAttn
from models.helpers import gumbel_softmax_with_rng, sample_with_top_k_top_p_
from models.vqvae import VQVAE, VectorQuantizer2
class SharedAdaLin(nn.Linear):
    """Linear projection whose output is split into six equally sized AdaLN chunks.

    The layer has ``out_features = 6 * C``; the projected features are viewed as
    ``(-1, 1, 6, C)`` so the six C-wide chunks can be broadcast over a token axis.
    """

    def forward(self, cond_BD):
        chunk_width = self.weight.shape[0] // 6
        projected = super().forward(cond_BD)  # presumably (B, 6*C) for a (B, D) input
        return projected.view(-1, 1, 6, chunk_width)  # B16C
class VAR(nn.Module):
    """Visual AutoRegressive (VAR) transformer conditioned on a class label and a delta vector.

    The token maps of every scale in ``patch_nums`` are flattened into one sequence of
    length ``L = sum(pn ** 2)``. The conditioning vector is
    ``alpha * class_emb(label) + beta * delta_condition``. Label index ``num_classes``
    together with a zero ``delta_condition`` is the unconditional input; it is used
    both for conditional dropout in training and for classifier-free guidance at inference.
    """

    def __init__(
        self, vae_local: VQVAE,
        num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0.,
        drop_path_rate=0.,
        norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,
        attn_l2_norm=False,
        patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),  # 10 steps by default
        flash_if_available=True, fused_if_available=True,
    ):
        """
        :param vae_local: VQVAE providing ``Cvae``, ``vocab_size``, ``quantize`` and ``fhat_to_img``.
            It is only stored inside tuples, so nn.Module does not register it as a submodule.
        :param num_classes: number of class labels; index ``num_classes`` is the extra unconditional class
        :param depth: number of AdaLNSelfAttn blocks
        :param embed_dim: transformer width C (also used as the conditioning width D)
        :param num_heads: attention heads; must divide ``embed_dim``
        :param mlp_ratio, drop_rate, attn_drop_rate: forwarded to every block
        :param drop_path_rate: largest stochastic-depth rate, increased linearly over the blocks
        :param norm_eps: LayerNorm epsilon
        :param shared_aln: compute one AdaLN projection (SharedAdaLin) and feed it to every block
        :param cond_drop_rate: probability of swapping in the unconditional input during training
        :param attn_l2_norm, flash_if_available, fused_if_available: forwarded to every block
        :param patch_nums: side length of each scale's square token map, in generation order
        """
        super().__init__()
        # 0. hyperparameters
        assert embed_dim % num_heads == 0
        self.Cvae, self.V = vae_local.Cvae, vae_local.vocab_size
        self.depth, self.C, self.D, self.num_heads = depth, embed_dim, embed_dim, num_heads
        self.cond_drop_rate = cond_drop_rate
        self.prog_si = -1  # progressive training
        self.patch_nums: Tuple[int] = patch_nums
        self.L = sum(pn ** 2 for pn in self.patch_nums)
        self.first_l = self.patch_nums[0] ** 2
        # (begin, end) offsets of each scale's tokens inside the flattened length-L sequence
        self.begin_ends = []
        cur = 0
        for pn in self.patch_nums:
            self.begin_ends.append((cur, cur + pn ** 2))
            cur += pn ** 2
        self.num_stages_minus_1 = len(self.patch_nums) - 1
        # CPU generator reseeded by autoregressive_infer_cfg when g_seed is given
        self.rng = torch.Generator(device="cpu")

        # 1. input (word) embedding
        quant: VectorQuantizer2 = vae_local.quantize
        self.vae_proxy: Tuple[VQVAE] = (vae_local,)
        self.vae_quant_proxy: Tuple[VectorQuantizer2] = (quant,)
        self.word_embed = nn.Linear(self.Cvae, self.C)

        # 2. class embedding (one extra row for the unconditional class)
        init_std = math.sqrt(1 / self.C / 3)
        self.num_classes = num_classes
        self.uniform_prob = torch.full((1, num_classes), fill_value=1.0 / num_classes, dtype=torch.float32,
                                       device=dist.get_device())
        self.class_emb = nn.Embedding(self.num_classes + 1, self.C)
        nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std)
        self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
        nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)

        # 3. absolute position embedding
        pos_1LC = []
        for pn in self.patch_nums:
            pe = torch.empty(1, pn * pn, self.C)
            nn.init.trunc_normal_(pe, mean=0, std=init_std)
            pos_1LC.append(pe)
        pos_1LC = torch.cat(pos_1LC, dim=1)  # 1, L, C
        assert tuple(pos_1LC.shape) == (1, self.L, self.C)
        self.pos_1LC = nn.Parameter(pos_1LC)
        # level embedding (similar to GPT's segment embedding, used to distinguish different levels of token pyramid)
        self.lvl_embed = nn.Embedding(len(self.patch_nums), self.C)
        nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)

        # 4. backbone blocks
        self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False),
                                            SharedAdaLin(self.D, 6 * self.C)) if shared_aln else nn.Identity()
        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        self.drop_path_rate = drop_path_rate
        # stochastic depth decay rule (linearly increasing)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            AdaLNSelfAttn(
                cond_dim=self.D, shared_aln=shared_aln,
                block_idx=block_idx, embed_dim=self.C, norm_layer=norm_layer, num_heads=num_heads, mlp_ratio=mlp_ratio,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[block_idx],
                last_drop_p=0 if block_idx == 0 else dpr[block_idx - 1],
                attn_l2_norm=attn_l2_norm,
                flash_if_available=flash_if_available, fused_if_available=fused_if_available,
            )
            for block_idx in range(depth)
        ])
        fused_add_norm_fns = [b.fused_add_norm_fn is not None for b in self.blocks]
        self.using_fused_add_norm_fn = any(fused_add_norm_fns)
        print(
            f'\n[constructor] ==== flash_if_available={flash_if_available} ({sum(b.attn.using_flash for b in self.blocks)}/{self.depth}), fused_if_available={fused_if_available} (fusing_add_ln={sum(fused_add_norm_fns)}/{self.depth}, fusing_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.blocks)}/{self.depth}) ==== \n'
            f' [VAR config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}\n'
            f' [drop ratios ] drop_rate={drop_rate}, attn_drop_rate={attn_drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})',
            end='\n\n', flush=True
        )

        # 5. attention mask used in training (for masking out the future)
        # it won't be used in inference, since kv cache is enabled
        # d[0, i, 0] is the scale index of token i
        d: torch.Tensor = torch.cat(
            [torch.full((pn * pn,), i) for i, pn in enumerate(self.patch_nums)]
        ).view(1, self.L, 1)
        dT = d.transpose(1, 2)  # dT: 11L
        lvl_1L = dT[:, 0].contiguous()  # (1, L): scale index of every token
        self.register_buffer('lvl_1L', lvl_1L)
        # entry (i, j) is 0 when token j's scale <= token i's scale, else -inf (block-causal over scales)
        attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, self.L, self.L)
        self.register_buffer('attn_bias_for_masking', attn_bias_for_masking.contiguous())

        # 6. classifier head
        self.head_nm = AdaLNBeforeHead(self.C, self.D, norm_layer=norm_layer)
        self.head = nn.Linear(self.C, self.V)

    def get_logits(self, h_or_h_and_residual: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
                   cond_BD: Optional[torch.Tensor]):
        """Map final hidden states to vocabulary logits, computed in float32.

        If the blocks return a ``(h, residual)`` pair (fused add+norm path), the last
        block's drop-path and residual add are applied here before the head.
        """
        if not isinstance(h_or_h_and_residual, torch.Tensor):
            h, resi = h_or_h_and_residual  # fused_add_norm must be used
            h = resi + self.blocks[-1].drop_path(h)
        else:  # fused_add_norm is not used
            h = h_or_h_and_residual
        return self.head(self.head_nm(h.float(), cond_BD).float()).float()

    @torch.no_grad()
    def autoregressive_infer_cfg(
        self, B: int, label_B: Optional[Union[int, torch.LongTensor]],
        delta_condition: torch.Tensor, alpha: float, beta: float,
        g_seed: Optional[int] = None, cfg=1.5, top_k=0, top_p=0.0,
        more_smooth=False,
    ) -> torch.Tensor:  # returns reconstructed image (B, 3, H, W) in [0, 1]
        """
        Generate images using autoregressive inference with classifier-free guidance.

        :param B: batch size
        :param label_B: class labels; if None, sampled uniformly; a negative int selects the unconditional class
        :param delta_condition: tensor of shape (B, D)
        :param alpha: scalar weight for class embedding
        :param beta: scalar weight for delta_condition
        :param g_seed: random seed; when given, self.rng is reseeded and used for all sampling
        :param cfg: classifier-free guidance ratio (reached at the last scale, 0 at the first)
        :param top_k: top-k sampling
        :param top_p: top-p sampling
        :param more_smooth: smoothing the pred using gumbel softmax; only used in visualization, not used in FID/IS benchmarking
        :return: reconstructed images (B, 3, H, W)
        """
        if g_seed is None:
            rng = None
        else:
            self.rng.manual_seed(g_seed)
            rng = self.rng

        device = self.lvl_1L.device
        if label_B is None:
            # self.rng is a CPU generator, and torch.multinomial needs the generator on the same
            # device as its input, so seeded labels are drawn on the generator's device and then moved.
            probs = self.uniform_prob if rng is None else self.uniform_prob.to(rng.device)
            label_B = torch.multinomial(probs, num_samples=B, replacement=True, generator=rng).reshape(B).to(device)
        elif isinstance(label_B, int):
            label_B = torch.full((B,), fill_value=self.num_classes if label_B < 0 else label_B, device=device)

        # CFG batch: rows [:B] are conditioned, rows [B:] use the unconditional class and a zero delta
        label_B = torch.cat((label_B, torch.full_like(label_B, fill_value=self.num_classes)), dim=0)  # (2B,)
        delta_condition = torch.cat((delta_condition, torch.zeros_like(delta_condition)), dim=0)  # (2B, D)

        cond_BD = alpha * self.class_emb(label_B) + beta * delta_condition  # (2B, D)
        sos = cond_BD.unsqueeze(1).expand(2 * B, self.first_l, -1) + self.pos_start.expand(2 * B, self.first_l, -1)
        lvl_pos = self.lvl_embed(self.lvl_1L) + self.pos_1LC
        next_token_map = sos + lvl_pos[:, :self.first_l]
        # the conditioning does not change between scales, so its AdaLN projection is computed once
        cond_BD_or_gss = self.shared_ada_lin(cond_BD)

        cur_L = 0
        f_hat = sos.new_zeros(B, self.Cvae, self.patch_nums[-1], self.patch_nums[-1])
        # avoid ZeroDivisionError when patch_nums has a single scale
        ratio_denom = max(self.num_stages_minus_1, 1)

        for b in self.blocks:
            b.attn.kv_caching(True)
        try:
            for si, pn in enumerate(self.patch_nums):  # si: i-th segment
                ratio = si / ratio_denom
                cur_L += pn * pn
                x = next_token_map
                for b in self.blocks:
                    x = b(x=x, cond_BD=cond_BD_or_gss, attn_bias=None)
                logits_BlV = self.get_logits(x, cond_BD)

                # guidance strength grows linearly from 0 at the first scale to cfg at the last
                t = cfg * ratio
                logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]

                idx_Bl = sample_with_top_k_top_p_(logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1)[:, :, 0]
                if not more_smooth:  # this is the default case
                    h_BChw = self.vae_quant_proxy[0].embedding(idx_Bl)  # B, l, Cvae
                else:  # not used when evaluating FID/IS/Precision/Recall
                    gum_t = max(0.27 * (1 - ratio * 0.95), 0.005)  # refer to mask-git
                    h_BChw = gumbel_softmax_with_rng(
                        logits_BlV.mul(1 + ratio), tau=gum_t, hard=False, dim=-1, rng=rng
                    ) @ self.vae_quant_proxy[0].embedding.weight.unsqueeze(0)
                h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.Cvae, pn, pn)
                f_hat, next_token_map = self.vae_quant_proxy[0].get_next_autoregressive_input(
                    si, len(self.patch_nums), f_hat, h_BChw
                )
                if si != self.num_stages_minus_1:  # prepare for next stage
                    next_token_map = next_token_map.view(B, self.Cvae, -1).transpose(1, 2)
                    next_token_map = (self.word_embed(next_token_map)
                                      + lvl_pos[:, cur_L:cur_L + self.patch_nums[si + 1] ** 2])
                    next_token_map = next_token_map.repeat(2, 1, 1)  # double the batch sizes due to CFG
        finally:
            # always leave the attention layers with KV caching disabled, even if sampling fails
            for b in self.blocks:
                b.attn.kv_caching(False)

        return self.vae_proxy[0].fhat_to_img(f_hat).add_(1).mul_(0.5)  # de-normalize, from [-1, 1] to [0, 1]

    def forward(self, label_B: torch.LongTensor, x_BLCv_wo_first_l: torch.Tensor, delta_condition: torch.Tensor,
                alpha: float, beta: float) -> torch.Tensor:
        """
        Teacher-forcing pass. With probability ``cond_drop_rate`` per sample, the label is
        replaced by the unconditional class and that sample's delta_condition is zeroed.

        :param label_B: label_B
        :param x_BLCv_wo_first_l: teacher forcing input (B, self.L-self.first_l, self.Cvae)
        :param delta_condition: tensor of shape (B, D)
        :param alpha: scalar weight for class embedding
        :param beta: scalar weight for delta_condition
        :return: logits BLV, V is vocab_size
        """
        # under progressive training only the first `ed` tokens (scales 0..prog_si) are processed
        ed = self.begin_ends[self.prog_si][1] if self.prog_si >= 0 else self.L
        B = x_BLCv_wo_first_l.shape[0]
        with torch.autocast('cuda', enabled=False):
            # Implement conditional dropout
            drop_mask = torch.rand(B, device=label_B.device) < self.cond_drop_rate
            label_B_dropped = torch.where(drop_mask, self.num_classes, label_B)
            delta_condition_dropped = delta_condition.clone()
            delta_condition_dropped[drop_mask] = 0.0  # Drop delta_condition

            cond_BD = alpha * self.class_emb(label_B_dropped) + beta * delta_condition_dropped
            sos = cond_BD.unsqueeze(1).expand(B, self.first_l, -1) + self.pos_start.expand(B, self.first_l, -1)
            if self.prog_si == 0:
                x_BLC = sos
            else:
                x_BLC = torch.cat((sos, self.word_embed(x_BLCv_wo_first_l.float())), dim=1)
            x_BLC += self.lvl_embed(self.lvl_1L[:, :ed].expand(B, -1)) + self.pos_1LC[:, :ed]  # lvl: BLC; pos: 1LC

        attn_bias = self.attn_bias_for_masking[:, :, :ed, :ed]
        cond_BD_or_gss = self.shared_ada_lin(cond_BD)

        # hack: get the dtype if mixed precision is used
        temp = x_BLC.new_ones(8, 8)
        main_type = torch.matmul(temp, temp).dtype

        x_BLC = x_BLC.to(dtype=main_type)
        cond_BD_or_gss = cond_BD_or_gss.to(dtype=main_type)
        attn_bias = attn_bias.to(dtype=main_type)

        for b in self.blocks:
            x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, attn_bias=attn_bias)
        x_BLC = self.get_logits(x_BLC.float(), cond_BD)

        if self.prog_si == 0:
            # word_embed is unused when only the first scale is trained; add a zero-valued term that
            # touches its parameters so they still take part in the autograd graph
            if isinstance(self.word_embed, nn.Linear):
                x_BLC[0, 0, 0] += self.word_embed.weight[0, 0] * 0 + self.word_embed.bias[0] * 0
            else:
                s = 0
                for p in self.word_embed.parameters():
                    if p.requires_grad:
                        s += p.view(-1)[0] * 0
                x_BLC[0, 0, 0] += s
        return x_BLC  # logits BLV, V is vocab_size

    def init_weights(self, init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=0.02, conv_std_or_gain=0.02):
        """Re-initialize the transformer's parameters.

        :param init_adaln: scale applied to the AdaLN projection rows from ``2*C`` onward (and to the head's AdaLN)
        :param init_adaln_gamma: scale applied to the first ``2*C`` AdaLN projection rows
        :param init_head: scale for the classifier head weight; a negative value skips the head
        :param init_std: std for Linear/Embedding weights; negative means ``sqrt(1 / C / 3)``
        :param conv_std_or_gain: >0: trunc-normal std for conv weights; <=0: xavier gain = -value
        """
        if init_std < 0: init_std = (1 / self.C / 3) ** 0.5  # init_std < 0: automated
        print(f'[init_weights] {type(self).__name__} with {init_std=:g}')
        for m in self.modules():
            with_weight = hasattr(m, 'weight') and m.weight is not None
            with_bias = hasattr(m, 'bias') and m.bias is not None
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight.data, std=init_std)
                if with_bias: m.bias.data.zero_()
            elif isinstance(m, nn.Embedding):
                nn.init.trunc_normal_(m.weight.data, std=init_std)
                if m.padding_idx is not None: m.weight.data[m.padding_idx].zero_()
            elif isinstance(m, (
                nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm,
                nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
                if with_weight: m.weight.data.fill_(1.)
                if with_bias: m.bias.data.zero_()
            # conv: VAR has no conv, only VQVAE has conv
            elif isinstance(m, (
                nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
                if conv_std_or_gain > 0:
                    nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain)
                else:
                    nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain)
                if with_bias: m.bias.data.zero_()

        if init_head >= 0:
            if isinstance(self.head, nn.Linear):
                self.head.weight.data.mul_(init_head)
                self.head.bias.data.zero_()
            elif isinstance(self.head, nn.Sequential):
                self.head[-1].weight.data.mul_(init_head)
                self.head[-1].bias.data.zero_()

        if isinstance(self.head_nm, AdaLNBeforeHead):
            self.head_nm.ada_lin[-1].weight.data.mul_(init_adaln)
            if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None:
                self.head_nm.ada_lin[-1].bias.data.zero_()

        depth = len(self.blocks)
        for block_idx, sab in enumerate(self.blocks):
            sab: AdaLNSelfAttn
            # shrink residual-branch output projections by sqrt(2 * depth)
            sab.attn.proj.weight.data.div_(math.sqrt(2 * depth))
            sab.ffn.fc2.weight.data.div_(math.sqrt(2 * depth))
            if hasattr(sab.ffn, 'fcg') and sab.ffn.fcg is not None:
                nn.init.ones_(sab.ffn.fcg.bias)
                nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5)
            if hasattr(sab, 'ada_lin'):
                sab.ada_lin[-1].weight.data[2 * self.C:].mul_(init_adaln)
                sab.ada_lin[-1].weight.data[:2 * self.C].mul_(init_adaln_gamma)
                if hasattr(sab.ada_lin[-1], 'bias') and sab.ada_lin[-1].bias is not None:
                    sab.ada_lin[-1].bias.data.zero_()
            elif hasattr(sab, 'ada_gss'):
                sab.ada_gss.data[:, :, 2:].mul_(init_adaln)
                sab.ada_gss.data[:, :, :2].mul_(init_adaln_gamma)

    def extra_repr(self):
        return f'drop_path_rate={self.drop_path_rate:g}'
class VARHF(VAR, PyTorchModelHubMixin):
    """Hub-loadable VAR variant that builds its own VQVAE from ``vae_kwargs``.

    Every other constructor argument is passed through unchanged to :class:`VAR`.
    """

    def __init__(
        self,
        vae_kwargs,
        num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0.,
        drop_path_rate=0.,
        norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,
        attn_l2_norm=False,
        patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),  # 10 steps by default
        flash_if_available=True, fused_if_available=True,
    ):
        var_kwargs = dict(
            num_classes=num_classes,
            depth=depth,
            embed_dim=embed_dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            norm_eps=norm_eps,
            shared_aln=shared_aln,
            cond_drop_rate=cond_drop_rate,
            attn_l2_norm=attn_l2_norm,
            patch_nums=patch_nums,
            flash_if_available=flash_if_available,
            fused_if_available=fused_if_available,
        )
        # the VQVAE is built before the VAR constructor runs, exactly as VAR expects
        super().__init__(vae_local=VQVAE(**vae_kwargs), **var_kwargs)