import importlib

import numpy as np
import taming
import torch
import yaml
from omegaconf import OmegaConf
from PIL import Image
from taming.models.vqgan import GumbelVQ, VQModel
from utils import get_device

# import discriminator


def load_config(config_path, display=False):
    config = OmegaConf.load(config_path)
    if display:
        print(yaml.dump(OmegaConf.to_container(config)))
    return config


# def load_disc(device):
#     dconf = load_config("disc_config.yaml")
#     sd = torch.load("disc.pt", map_location=device)
#     # print(sd.keys())
#     model = discriminator.NLayerDiscriminator()
#     model.load_state_dict(sd, strict=True)
#     model.to(device)
#     return model
#     # print(dconf.keys())


def load_default(device):
    # build the CelebA-HQ VQGAN from the unwrapped config and load its weights
    # device = get_device()
    ckpt_path = "logs/2021-04-23T18-11-19_celebahq_transformer/checkpoints/last.ckpt"
    conf_path = "./unwrapped.yaml"
    config = load_config(conf_path, display=False)
    model = taming.models.vqgan.VQModel(**config.model.params)
    sd = torch.load("./vqgan_only.pt", map_location=device)
    model.load_state_dict(sd, strict=True)
    model.to(device)
    return model


def load_vqgan(config, ckpt_path=None, is_gumbel=False):
    # instantiate a (Gumbel-)VQGAN from a config and optionally restore a checkpoint
    if is_gumbel:
        model = GumbelVQ(**config.model.params)
    else:
        model = VQModel(**config.model.params)
    if ckpt_path is not None:
        sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        missing, unexpected = model.load_state_dict(sd, strict=False)
    return model.eval()


def load_ffhq():
    # load the pretrained FacesHQ VQGAN checkpoint
    conf = "2020-11-09T13-33-36_faceshq_vqgan/configs/2020-11-09T13-33-36-project.yaml"
    ckpt = "2020-11-09T13-33-36_faceshq_vqgan/checkpoints/last.ckpt"
    vqgan = load_model(load_config(conf), ckpt, True, True)[0]
    return vqgan


def reconstruct_with_vqgan(x, model):
    # could also use model(x) for reconstruction, but use explicit encoding and decoding here
    z, _, [_, _, indices] = model.encode(x)
    print(f"VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}")
    xrec = model.decode(z)
    return xrec


def get_obj_from_str(string, reload=False):
    # resolve a dotted import path like "package.module.Class" to the object it names
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    model = instantiate_from_config(config)
    if sd is not None:
        model.load_state_dict(sd)
    if gpu:
        model.cuda()
    if eval_mode:
        model.eval()
    return {"model": model}


def load_model(config, ckpt, gpu, eval_mode):
    # load the specified checkpoint
    if ckpt:
        pl_sd = torch.load(ckpt, map_location="cpu")
        global_step = pl_sd["global_step"]
        print(f"loaded model from global step {global_step}.")
    else:
        pl_sd = {"state_dict": None}
        global_step = None
    model = load_model_from_config(config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode)["model"]
    return model, global_step
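

# Usage sketch (not part of the original module): load the default CelebA-HQ VQGAN
# and reconstruct a single image with it. Assumes the paths referenced in
# `load_default` exist, that `utils.get_device` returns a torch.device, and that
# "example.png" / "reconstruction.png" are placeholder file names of your choosing.
if __name__ == "__main__":
    device = get_device()
    model = load_default(device).eval()

    # VQGAN expects inputs scaled to [-1, 1] with shape (B, 3, H, W).
    img = Image.open("example.png").convert("RGB").resize((256, 256))
    x = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 127.5 - 1.0
    x = x.unsqueeze(0).to(device)

    with torch.no_grad():
        xrec = reconstruct_with_vqgan(x, model)

    # map the reconstruction back to [0, 255] and save it
    out = ((xrec[0].clamp(-1, 1) + 1) * 127.5).permute(1, 2, 0).cpu().numpy().astype(np.uint8)
    Image.fromarray(out).save("reconstruction.png")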