import argparse

import torch

from .flux_vqgan import AutoEncoder
|
def load_cnn(model, state_dict, prefix, expand=False, use_linear=False):
    """Copy checkpoint weights under `prefix` into `model` in place.

    Handles three cases: direct name matches, 1x1-conv weights collapsed to
    linear layers (`use_linear`), and 2D conv kernels inflated along a new
    temporal axis to fit 3D convs (`expand`). Checkpoint keys that name bare
    conv/norm layers are remapped onto the model's wrapped `.conv.*` /
    `.norm.*` parameters. Returns the model, the state dict stripped of
    consumed keys, and the list of loaded keys.
    """
    # Fetch the model's state dict once; its tensors share storage with the
    # model parameters, so `copy_` below updates the model directly.
    model_state = model.state_dict()
    delete_keys = []
    loaded_keys = []
    # Substrings identifying conv-like checkpoint keys that need remapping.
    conv_list = (
        ["conv"]
        if use_linear
        else ["conv", ".q.", ".k.", ".v.", ".proj_out.", ".nin_shortcut."]
    )
    for key in state_dict:
        if not key.startswith(prefix):
            continue
        _key = key[len(prefix) :]
        if _key in model_state:
            if use_linear and (
                ".q.weight" in key
                or ".k.weight" in key
                or ".v.weight" in key
                or ".proj_out.weight" in key
            ):
                # 1x1 conv weights squeeze down to linear-layer weights.
                load_weights = state_dict[key].squeeze()
            elif _key.endswith(".conv.weight") and expand:
                if model_state[_key].shape == state_dict[key].shape:
                    # Shapes already match: plain copy.
                    load_weights = state_dict[key]
                else:
                    # Inflate the 2D kernel by repeating it along the new
                    # temporal axis to fill the 3D conv weight.
                    _expand_dim = model_state[_key].shape[2]
                    load_weights = (
                        state_dict[key].unsqueeze(2).repeat(1, 1, _expand_dim, 1, 1)
                    )
            else:
                load_weights = state_dict[key]
            model_state[_key].copy_(load_weights)
            delete_keys.append(key)
            loaded_keys.append(prefix + _key)

        # Bare conv layers in the checkpoint map onto the model's wrapped
        # `.conv.weight` / `.conv.bias` parameters.
        if any(k in _key for k in conv_list):
            if _key.endswith(".weight"):
                conv_key = _key.replace(".weight", ".conv.weight")
                if conv_key and conv_key in model_state:
                    if model_state[conv_key].shape == state_dict[key].shape:
                        # Shapes already match: plain copy.
                        load_weights = state_dict[key]
                    else:
                        # Inflate the 2D kernel along the temporal axis.
                        _expand_dim = model_state[conv_key].shape[2]
                        load_weights = (
                            state_dict[key]
                            .unsqueeze(2)
                            .repeat(1, 1, _expand_dim, 1, 1)
                        )
                    model_state[conv_key].copy_(load_weights)
                    delete_keys.append(key)
                    loaded_keys.append(prefix + conv_key)
            if _key.endswith(".bias"):
                conv_key = _key.replace(".bias", ".conv.bias")
                if conv_key and conv_key in model_state:
                    model_state[conv_key].copy_(state_dict[key])
                    delete_keys.append(key)
                    loaded_keys.append(prefix + conv_key)

        # Likewise, bare norm layers map onto `.norm.weight` / `.norm.bias`.
        if "norm" in _key:
            if _key.endswith(".weight"):
                norm_key = _key.replace(".weight", ".norm.weight")
                if norm_key and norm_key in model_state:
                    model_state[norm_key].copy_(state_dict[key])
                    delete_keys.append(key)
                    loaded_keys.append(prefix + norm_key)
            if _key.endswith(".bias"):
                norm_key = _key.replace(".bias", ".norm.bias")
                if norm_key and norm_key in model_state:
                    model_state[norm_key].copy_(state_dict[key])
                    delete_keys.append(key)
                    loaded_keys.append(prefix + norm_key)

    for key in delete_keys:
        del state_dict[key]

    return model, state_dict, loaded_keys
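# Illustration of the remapping performed by `load_cnn` (a sketch only; the
# concrete layer names below are assumptions, not keys guaranteed to exist in
# any particular checkpoint):
#   checkpoint "encoder.conv_in.weight" -> model "encoder.conv_in.conv.weight"
#   checkpoint "encoder.norm_out.bias"  -> model "encoder.norm_out.norm.bias"
# With expand=True, a 2D kernel of shape (O, I, H, W) becomes (O, I, T, H, W)
# by repeating it T times along the inserted temporal dimension.

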
def vae_model(
    vqgan_ckpt,
    schedule_mode,
    codebook_dim,
    codebook_size,
    test_mode=True,
    patch_size=16,
    encoder_ch_mult=[1, 2, 4, 4, 4],
    decoder_ch_mult=[1, 2, 4, 4, 4],
):
    """Build the Flux VQGAN `AutoEncoder` and optionally load a checkpoint.

    Only the checkpoint, quantizer schedule, codebook shape, and a few
    architecture knobs are exposed as parameters; every other training-time
    option is pinned to a fixed default in the namespace below.
    """
    args = argparse.Namespace(
        vqgan_ckpt=vqgan_ckpt,
        sd_ckpt=None,
        inference_type="image",
        save="./imagenet_val_bsq",
        save_prediction=True,
        image_recon4video=False,
        junke_old=False,
        device="cuda",
        max_steps=1000000.0,
        log_every=1,
        visu_every=1000,
        ckpt_every=1000,
        default_root_dir="",
        compile="no",
        ema="no",
        lr=0.0001,
        beta1=0.9,
        beta2=0.95,
        warmup_steps=0,
        optim_type="Adam",
        disc_optim_type=None,
        lr_min=0.0,
        warmup_lr_init=0.0,
        max_grad_norm=1.0,
        max_grad_norm_disc=1.0,
        disable_sch=False,
        patch_size=patch_size,
        temporal_patch_size=4,
        embedding_dim=256,
        codebook_dim=codebook_dim,
        num_quantizers=8,
        quantizer_type="MultiScaleBSQ",
        use_vae=False,
        use_freq_enc=False,
        use_freq_dec=False,
        preserve_norm=False,
        ln_before_quant=False,
        ln_init_by_sqrt=False,
        use_pxsf=False,
        new_quant=True,
        use_decay_factor=False,
        mask_out=False,
        use_stochastic_depth=False,
        drop_rate=0.0,
        schedule_mode=schedule_mode,
        lr_drop=None,
        lr_drop_rate=0.1,
        keep_first_quant=False,
        keep_last_quant=False,
        remove_residual_detach=False,
        use_out_phi=False,
        use_out_phi_res=False,
        use_lecam_reg=False,
        lecam_weight=0.05,
        perceptual_model="vgg16",
        base_ch_disc=64,
        random_flip=False,
        flip_prob=0.5,
        flip_mode="stochastic",
        max_flip_lvl=1,
        not_load_optimizer=False,
        use_lecam_reg_zero=False,
        freeze_encoder=False,
        rm_downsample=False,
        random_flip_1lvl=False,
        flip_lvl_idx=0,
        drop_when_test=False,
        drop_lvl_idx=0,
        drop_lvl_num=1,
        disc_version="v1",
        magvit_disc=False,
        sigmoid_in_disc=False,
        activation_in_disc="leaky_relu",
        apply_blur=False,
        apply_noise=False,
        dis_warmup_steps=0,
        dis_lr_multiplier=1.0,
        dis_minlr_multiplier=False,
        disc_channels=64,
        disc_layers=3,
        discriminator_iter_start=0,
        disc_pretrain_iter=0,
        disc_optim_steps=1,
        disc_warmup=0,
        disc_pool="no",
        disc_pool_size=1000,
        advanced_disc=False,
        recon_loss_type="l1",
        video_perceptual_weight=0.0,
        image_gan_weight=1.0,
        video_gan_weight=1.0,
        image_disc_weight=0.0,
        video_disc_weight=0.0,
        l1_weight=4.0,
        gan_feat_weight=0.0,
        perceptual_weight=0.0,
        kl_weight=0.0,
        lfq_weight=0.0,
        entropy_loss_weight=0.1,
        commitment_loss_weight=0.25,
        diversity_gamma=1,
        norm_type="group",
        disc_loss_type="hinge",
        use_checkpoint=False,
        precision="fp32",
        encoder_dtype="fp32",
        upcast_attention="",
        upcast_tf32=False,
        tokenizer="flux",
        pretrained=None,
        pretrained_mode="full",
        inflation_pe=False,
        init_vgen="no",
        no_init_idis=False,
        init_idis="keep",
        init_vdis="no",
        enable_nan_detector=False,
        turn_on_profiler=False,
        profiler_scheduler_wait_steps=10,
        debug=True,
        video_logger=False,
        bytenas="",
        username="",
        seed=1234,
        vq_to_vae=False,
        load_not_strict=False,
        zero=0,
        bucket_cap_mb=40,
        manual_gc_interval=1000,
        data_path=[""],
        data_type=[""],
        dataset_list=["imagenet"],
        fps=-1,
        dataaug="resizecrop",
        multi_resolution=False,
        random_bucket_ratio=0.0,
        sequence_length=16,
        resolution=[256, 256],
        batch_size=[1],
        num_workers=0,
        image_channels=3,
        codebook_size=codebook_size,
        codebook_l2_norm=True,
        codebook_show_usage=True,
        commit_loss_beta=0.25,
        entropy_loss_ratio=0.0,
        base_ch=128,
        num_res_blocks=2,
        encoder_ch_mult=encoder_ch_mult,
        decoder_ch_mult=decoder_ch_mult,
        dropout_p=0.0,
        cnn_type="2d",
        cnn_version="v1",
        conv_in_out_2d="no",
        conv_inner_2d="no",
        res_conv_2d="no",
        cnn_attention="no",
        cnn_norm_axis="spatial",
        flux_weight=0,
        cycle_weight=0,
        cycle_feat_weight=0,
        cycle_gan_weight=0,
        cycle_loop=0,
        z_drop=0.0,
    )
    vae = AutoEncoder(args)
    use_vae = vae.use_vae
    if not use_vae:
        num_codes = args.codebook_size  # unused here; kept for parity with training code

    # Accept either a checkpoint path or an already-loaded state dict.
    if isinstance(vqgan_ckpt, str):
        state_dict = torch.load(
            args.vqgan_ckpt, map_location=torch.device("cpu"), weights_only=True
        )
    else:
        state_dict = args.vqgan_ckpt
    if state_dict:
        # Load EMA weights when the namespace requests them, otherwise the
        # raw VAE weights.
        if args.ema == "yes":
            vae, new_state_dict, loaded_keys = load_cnn(
                vae, state_dict["ema"], prefix="", expand=False
            )
        else:
            vae, new_state_dict, loaded_keys = load_cnn(
                vae, state_dict["vae"], prefix="", expand=False
            )
    if test_mode:
        # Freeze the model for inference.
        vae.eval()
        for p in vae.parameters():
            p.requires_grad_(False)
    return vae
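
if __name__ == "__main__":
    # Minimal usage sketch (run as a module, e.g. `python -m <package>.<module>`,
    # since the import at the top is relative). The checkpoint path and the
    # schedule/codebook values below are placeholders (assumptions), not values
    # shipped with this file.
    vae = vae_model(
        vqgan_ckpt="path/to/vqgan_checkpoint.pt",  # hypothetical path
        schedule_mode="dynamic",  # assumed schedule name
        codebook_dim=64,
        codebook_size=2**16,
    )
    # test_mode defaults to True, so the model comes back in eval() mode with
    # every parameter frozen.
    print(f"parameters: {sum(p.numel() for p in vae.parameters()):,}")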