Face-editor / loaders.py
Erwann Millon
change default path
0be9cd5
raw
history blame
2.47 kB
import importlib
import numpy as np
import taming
import torch
import yaml
from omegaconf import OmegaConf
from PIL import Image
from taming.models.vqgan import VQModel
from utils import get_device
def load_config(config_path, display=False):
    """Load an OmegaConf configuration from ``config_path``.

    Args:
        config_path: Location of the config file, passed to ``OmegaConf.load``.
        display: When true, the loaded config is converted to a plain
            container and printed as YAML before being returned.

    Returns:
        The object produced by ``OmegaConf.load``.
    """
    cfg = OmegaConf.load(config_path)
    if not display:
        return cfg
    # Convert to a plain container first so yaml.dump can serialise it.
    plain = OmegaConf.to_container(cfg)
    print(yaml.dump(plain))
    return cfg
def load_default(
    device,
    conf_path="./model_checkpoints/vqgan_only.yaml",
    ckpt_path="./model_checkpoints/vqgan_only.pt",
):
    """Build the default VQGAN and load its weights onto ``device``.

    The config and checkpoint locations used to be hard-coded in the body;
    they are now keyword parameters whose defaults are the original paths,
    so existing ``load_default(device)`` calls behave exactly as before.

    Args:
        device: Target device; used both as ``map_location`` when reading the
            checkpoint and as the argument to ``model.to``.
        conf_path: Path of the YAML config whose ``model.params`` section is
            expanded into the ``VQModel`` constructor.
        ckpt_path: Path of the checkpoint file. Its top-level object is
            passed directly to ``load_state_dict``.

    Returns:
        The ``VQModel`` instance with weights loaded and moved to ``device``.
        ``eval()`` is not called here.

    Raises:
        RuntimeError: From ``load_state_dict`` when the checkpoint keys do
            not match the model exactly (``strict=True``).
    """
    config = load_config(conf_path, display=False)
    model = taming.models.vqgan.VQModel(**config.model.params)
    # NOTE: torch.load unpickles the file; only point this at trusted
    # checkpoints.
    state_dict = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    return model
def load_vqgan(config, ckpt_path=None, is_gumbel=False):
    """Construct a ``VQModel`` from ``config`` and return it in eval mode.

    Args:
        config: Config object whose ``model.params`` mapping is expanded into
            the ``VQModel`` constructor.
        ckpt_path: Optional checkpoint path. When given, the file is read on
            the CPU and its ``"state_dict"`` entry is loaded non-strictly,
            so missing or unexpected keys do not raise.
        is_gumbel: Accepted but not used by this function.

    Returns:
        The model after ``eval()`` has been called on it.
    """
    vqgan = VQModel(**config.model.params)
    if ckpt_path is None:
        # No weights requested: hand back the freshly built model.
        return vqgan.eval()
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    weights = checkpoint["state_dict"]
    vqgan.load_state_dict(weights, strict=False)
    return vqgan.eval()
def reconstruct_with_vqgan(x, model):
    """Encode ``x`` with ``model`` and decode the result back.

    ``model.encode`` must return three items, the last of which is itself a
    three-element sequence; only the first item (the latent) is used. The
    model's class name and ``latent.shape[2:]`` are printed along the way.

    Args:
        x: Input passed unchanged to ``model.encode``.
        model: Object exposing ``encode`` and ``decode``.

    Returns:
        Whatever ``model.decode`` returns for the latent.
    """
    # The nested pattern keeps the same structural check on encode()'s output.
    latent, _, (_, _, _) = model.encode(x)
    model_name = model.__class__.__name__
    trailing_dims = latent.shape[2:]
    print(f"VQGAN --- {model_name}: latent shape: {trailing_dims}")
    return model.decode(latent)
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"pkg.module.Name"`` to the named object.

    The text after the last dot is the attribute name; everything before it
    is the module to import.

    Args:
        string: Dotted path containing at least one ``"."``.
        reload: When true, the module is reloaded before the attribute is
            looked up.

    Returns:
        The attribute fetched from the imported module.

    Raises:
        ValueError: If ``string`` contains no dot.
        ImportError: If the module cannot be imported.
        AttributeError: If the module has no such attribute.
    """
    module_path, attr_name = string.rsplit(".", 1)
    module_obj = importlib.import_module(module_path)
    if reload:
        # reload() returns the refreshed module object.
        module_obj = importlib.reload(module_obj)
    return getattr(module_obj, attr_name)
def instantiate_from_config(config):
    """Create an object described by a ``target`` / ``params`` mapping.

    Args:
        config: Mapping with a ``"target"`` dotted path and, optionally, a
            ``"params"`` mapping of keyword arguments (empty when absent).

    Returns:
        The result of calling the resolved target with ``params`` as kwargs.

    Raises:
        KeyError: If ``config`` has no ``"target"`` key.
    """
    if "target" in config:
        factory = get_obj_from_str(config["target"])
        kwargs = config.get("params", dict())
        return factory(**kwargs)
    raise KeyError("Expected key `target` to instantiate.")
def load_model_from_config(config, sd, gpu=True, eval_mode=True):
    """Instantiate a model from ``config`` and optionally load weights.

    Args:
        config: Mapping understood by ``instantiate_from_config``.
        sd: State dict to load into the model, or ``None`` to skip loading.
        gpu: When true, ``cuda()`` is called on the model.
        eval_mode: When true, ``eval()`` is called on the model.

    Returns:
        A dict with the single key ``"model"`` holding the prepared model.
    """
    net = instantiate_from_config(config)
    has_weights = sd is not None
    if has_weights:
        net.load_state_dict(sd)
    if gpu:
        net.cuda()
    if eval_mode:
        net.eval()
    return dict(model=net)
def load_model(config, ckpt, gpu, eval_mode):
    """Build the model described by ``config.model``, optionally from a checkpoint.

    Args:
        config: Config object exposing a ``model`` section.
        ckpt: Checkpoint path. When falsy, no weights are loaded.
        gpu: Forwarded to ``load_model_from_config``.
        eval_mode: Forwarded to ``load_model_from_config``.

    Returns:
        A ``(model, global_step)`` tuple. ``global_step`` is read from the
        checkpoint, or is ``None`` when no checkpoint was given.
    """
    if not ckpt:
        # Nothing to restore: build the model without weights.
        state_dict = None
        global_step = None
    else:
        checkpoint = torch.load(ckpt, map_location="cpu")
        global_step = checkpoint["global_step"]
        print(f"loaded model from global step {global_step}.")
        state_dict = checkpoint["state_dict"]
    bundle = load_model_from_config(
        config.model, state_dict, gpu=gpu, eval_mode=eval_mode
    )
    return bundle["model"], global_step