import importlib

import numpy as np
import torch
import yaml
from omegaconf import OmegaConf
from PIL import Image
from taming.models.vqgan import VQModel

from utils import get_device


def load_config(config_path, display=False):
    """Load an OmegaConf config from a YAML file, optionally printing it."""
    config = OmegaConf.load(config_path)
    if display:
        print(yaml.dump(OmegaConf.to_container(config)))
    return config


def load_default(device):
    """Load the default VQGAN checkpoint from ./model_checkpoints onto `device`."""
    conf_path = "./model_checkpoints/vqgan_only.yaml"
    config = load_config(conf_path, display=False)
    model = VQModel(**config.model.params)
    sd = torch.load("./model_checkpoints/vqgan_only.pt", map_location=device)
    model.load_state_dict(sd, strict=True)
    model.to(device)
    del sd
    return model


def load_vqgan(config, ckpt_path=None, is_gumbel=False):
    """Build a VQGAN from a config and optionally restore a Lightning checkpoint.

    Note: `is_gumbel` is accepted for API compatibility but is not used here.
    """
    model = VQModel(**config.model.params)
    if ckpt_path is not None:
        sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        missing, unexpected = model.load_state_dict(sd, strict=False)
    return model.eval()


def reconstruct_with_vqgan(x, model):
    """Encode `x` to the quantized latent space and decode it back."""
    z, _, [_, _, indices] = model.encode(x)
    print(f"VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}")
    xrec = model.decode(z)
    return xrec


def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as 'package.module.Class' to the object itself."""
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    """Instantiate the class named by `config['target']` with `config['params']`."""
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    """Instantiate a model from a config, then optionally load weights and move it to the GPU."""
    model = instantiate_from_config(config)
    if sd is not None:
        model.load_state_dict(sd)
    if gpu:
        model.cuda()
    if eval_mode:
        model.eval()
    return {"model": model}


def load_model(config, ckpt, gpu, eval_mode):
    """Load a model (and its global step) from a PyTorch Lightning checkpoint."""
    if ckpt:
        # Load the specified checkpoint.
        pl_sd = torch.load(ckpt, map_location="cpu")
        global_step = pl_sd["global_step"]
        print(f"Loaded model from global step {global_step}.")
    else:
        pl_sd = {"state_dict": None}
        global_step = None
    model = load_model_from_config(
        config.model, pl_sd["state_dict"], gpu=gpu, eval_mode=eval_mode
    )["model"]
    return model, global_step
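

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal example of how the
# loaders above might be wired together. The image path, output path, and the
# assumption that utils.get_device() returns a torch.device are hypothetical;
# the VQGAN expects inputs scaled to [-1, 1].
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torchvision.transforms as T

    device = get_device()
    model = load_default(device)

    # Hypothetical input image, resized/cropped to a 256x256 resolution.
    img = Image.open("example.png").convert("RGB")
    preprocess = T.Compose([T.Resize(256), T.CenterCrop(256), T.ToTensor()])
    x = preprocess(img).unsqueeze(0).to(device) * 2.0 - 1.0  # [0, 1] -> [-1, 1]

    with torch.no_grad():
        xrec = reconstruct_with_vqgan(x, model)

    # Map the reconstruction back to [0, 1] and save it.
    out = ((xrec[0].clamp(-1.0, 1.0) + 1.0) / 2.0).cpu()
    T.ToPILImage()(out).save("example_reconstruction.png")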