Face-editor / loaders.py
erwann's picture
wip statewrapper change
4e7a12f
raw
history blame
3.23 kB
import importlib
import numpy as np
import taming
import torch
import yaml
from omegaconf import OmegaConf
from PIL import Image
from taming.models.vqgan import VQModel
from utils import get_device
# import discriminator
def load_config(config_path, display=False):
    """Load an OmegaConf config from ``config_path``.

    Args:
        config_path: Path to the config file handed to ``OmegaConf.load``.
        display: When true, also print the config as YAML. It is converted to a
            plain container first, then dumped with ``yaml.dump``.

    Returns:
        The loaded OmegaConf config object.
    """
    cfg = OmegaConf.load(config_path)
    if not display:
        return cfg
    print(yaml.dump(OmegaConf.to_container(cfg)))
    return cfg
# def load_disc(device):
# dconf = load_config("disc_config.yaml")
# sd = torch.load("disc.pt", map_location=device)
# # print(sd.keys())
# model = discriminator.NLayerDiscriminator()
# model.load_state_dict(sd, strict=True)
# model.to(device)
# return model
# print(dconf.keys())
def load_default(device, conf_path="./unwrapped.yaml", sd_path="./vqgan_only.pt"):
    """Build the default ``VQModel`` and load its weights onto ``device``.

    Args:
        device: Passed to ``torch.load`` as ``map_location`` and to ``model.to``.
        conf_path: YAML config; its ``model.params`` become ``VQModel`` kwargs.
        sd_path: File holding a plain state dict for the model. It is loaded
            with ``strict=True``, so every key must match the model.

    Returns:
        The ``VQModel`` moved to ``device``. It is not switched to eval mode here.
    """
    config = load_config(conf_path, display=False)
    model = VQModel(**config.model.params)
    state_dict = torch.load(sd_path, map_location=device)
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    # Drop our reference to the loaded weights before returning.
    del state_dict
    return model
def load_vqgan(config, ckpt_path=None, is_gumbel=False):
    """Instantiate a VQGAN (or its Gumbel variant) and optionally load weights.

    Args:
        config: Loaded config; ``config.model.params`` are passed as keyword
            arguments to the model class.
        ckpt_path: Optional checkpoint path. The file is loaded on CPU, and its
            ``"state_dict"`` entry is loaded non-strictly.
        is_gumbel: Build ``GumbelVQ`` instead of ``VQModel``.

    Returns:
        The model in eval mode.
    """
    if is_gumbel:
        # Only VQModel is imported at module level, so GumbelVQ must be brought
        # into scope here; otherwise this branch raises NameError.
        from taming.models.vqgan import GumbelVQ

        model_cls = GumbelVQ
    else:
        model_cls = VQModel
    model = model_cls(**config.model.params)
    if ckpt_path is not None:
        sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        # strict=False: mismatched keys are silently ignored. Presumably the
        # checkpoint carries extra entries beyond the model's own; confirm.
        model.load_state_dict(sd, strict=False)
    return model.eval()
def load_ffhq(gpu=True, eval_mode=True):
    """Load the VQGAN from the ``2020-11-09T13-33-36_faceshq_vqgan`` run.

    Despite the function name, the config and checkpoint paths point at the
    ``faceshq_vqgan`` run directory.

    Args:
        gpu: Forwarded to ``load_model``, which moves the model to CUDA when true.
        eval_mode: Forwarded to ``load_model``, which switches to eval mode when true.

    Returns:
        The loaded model. Previously the result was discarded and the function
        returned ``None``.
    """
    conf = "2020-11-09T13-33-36_faceshq_vqgan/configs/2020-11-09T13-33-36-project.yaml"
    ckpt = "2020-11-09T13-33-36_faceshq_vqgan/checkpoints/last.ckpt"
    # load_model returns (model, global_step); only the model is needed here.
    vqgan = load_model(load_config(conf), ckpt, gpu, eval_mode)[0]
    return vqgan
def reconstruct_with_vqgan(x, model):
    """Round-trip ``x`` through ``model``'s encoder and decoder.

    The encode/decode steps are called explicitly instead of ``model(x)`` so
    that the latent's spatial shape can be printed.

    Args:
        x: Input passed to ``model.encode``.
        model: Object whose ``encode`` returns ``(latent, _, (_, _, _))`` and
            whose ``decode`` accepts that latent.

    Returns:
        Whatever ``model.decode`` returns for the latent.
    """
    quant, _emb_loss, (_, _, _codes) = model.encode(x)
    z = quant
    print(f"VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}")
    return model.decode(quant)
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"package.module.Name"`` to ``Name``.

    Args:
        string: Dotted path; everything before the last ``.`` is the module.
        reload: Re-execute the module with ``importlib.reload`` before lookup.

    Returns:
        The attribute named by the final path component.

    Raises:
        ValueError: If ``string`` contains no ``.``.
    """
    module_path, attr_name = string.rsplit(".", 1)
    target = importlib.import_module(module_path, package=None)
    if reload:
        # reload re-executes the module and returns the same sys.modules entry.
        target = importlib.reload(target)
    return getattr(target, attr_name)
def instantiate_from_config(config):
    """Construct the object described by ``config``.

    Args:
        config: Mapping with a ``"target"`` dotted path (resolved via
            ``get_obj_from_str``) and optional ``"params"`` keyword arguments.

    Returns:
        The result of calling the resolved target with ``params``.

    Raises:
        KeyError: If ``config`` has no ``"target"`` key.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    factory = get_obj_from_str(config["target"])
    kwargs = config.get("params", {})
    return factory(**kwargs)
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    """Instantiate the model described by ``config`` and prepare it for use.

    Args:
        config: Mapping accepted by ``instantiate_from_config``.
        sd: State dict loaded with ``load_state_dict`` (default arguments), or
            ``None`` to keep the freshly constructed weights.
        gpu: If true, move the model to CUDA with ``.cuda()``.
        eval_mode: If true, call ``.eval()`` on the model.

    Returns:
        A dict with the model under the ``"model"`` key.
    """
    net = instantiate_from_config(config)
    if sd is not None:
        net.load_state_dict(sd)
    if gpu:
        net.cuda()
    if eval_mode:
        net.eval()
    return dict(model=net)
def load_model(config, ckpt, gpu, eval_mode):
    """Build the model from ``config.model`` and optionally restore a checkpoint.

    Args:
        config: Config whose ``model`` entry is passed to ``load_model_from_config``.
        ckpt: Checkpoint path loaded on CPU. A falsy value means no weights
            are loaded. A checkpoint must contain ``"global_step"`` and
            ``"state_dict"`` keys.
        gpu: Forwarded to ``load_model_from_config``.
        eval_mode: Forwarded to ``load_model_from_config``.

    Returns:
        ``(model, global_step)``, where ``global_step`` is ``None`` when no
        checkpoint was given.
    """
    # Defaults for the "no checkpoint" case; overridden below when one is given.
    checkpoint = {"state_dict": None}
    global_step = None
    if ckpt:
        checkpoint = torch.load(ckpt, map_location="cpu")
        global_step = checkpoint["global_step"]
        print(f"loaded model from global step {global_step}.")
    model = load_model_from_config(
        config.model, checkpoint["state_dict"], gpu=gpu, eval_mode=eval_mode
    )["model"]
    return model, global_step