# AI_Gen_for_SG / scripts/evaluate_model.py
# Last commit by: 常舒宁 — 1dc89cf ("add files")
# (Copied from a file-hosting page: raw | history | blame — "No virus" — 2.72 kB)
import argparse, os, sys, glob
sys.path.append(os.path.join(sys.path[0], '..'))
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from einops import rearrange
from torchvision.utils import make_grid
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.data.personalized import PersonalizedBase
from evaluation.clip_eval import LDMCLIPEvaluator
def load_model_from_config(config, ckpt, verbose=False, device=None):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint readable by ``torch.load``; it must contain
            a ``"state_dict"`` entry.
        verbose: If True, print any missing / unexpected state-dict keys.
        device: Device to place the model on. When None (the default), CUDA is
            used if available and CPU otherwise. The previous implementation
            always called ``.cuda()``, which fails on CPU-only machines.

    Returns:
        The instantiated model, moved to ``device`` and switched to eval mode.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    # map_location="cpu" lets the weights load even without a GPU.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches; they are reported below when verbose.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose and missing:
        print("missing keys:")
        print(missing)
    if verbose and unexpected:
        print("unexpected keys:")
        print(unexpected)
    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)
    model.eval()
    return model
def main():
    """Score a trained embedding with ``LDMCLIPEvaluator`` and print the results.

    Loads the text-to-image LDM plus an embedding-manager checkpoint, and reads
    the training images from ``--data_dir``. It passes the model, those images
    and ``--prompt`` to ``LDMCLIPEvaluator.evaluate``, then prints the image and
    text similarity values that call returns.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render"
    )
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="/data/pretrained_models/ldm/text2img-large/model.ckpt",
        help="Path to pretrained ldm text2img model")
    parser.add_argument(
        "--embedding_path",
        type=str,
        required=True,
        help="Path to a pre-trained embedding manager checkpoint")
    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Path to directory with images used to train the embedding vectors"
    )
    parser.add_argument(
        "--config",
        type=str,
        default="configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml",
        help="Path to the model config (relative paths resolve against the current working directory)"
    )
    opt = parser.parse_args()

    # TODO: Optionally download the config from the same location as the ckpt.
    config = OmegaConf.load(opt.config)
    model = load_model_from_config(config, opt.ckpt_path)  # TODO: check path
    model.embedding_manager.load(opt.embedding_path)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    evaluator = LDMCLIPEvaluator(device)

    dataset = PersonalizedBase(opt.data_dir, size=256, flip_p=0.0)
    if dataset.num_images == 0:
        # torch.stack on an empty list would otherwise fail with an opaque error.
        parser.error(f"no images found in --data_dir {opt.data_dir!r}")

    # Each item's "image" is a NumPy array; permute(2, 0, 1) moves its last axis
    # to the front (presumably HWC -> CHW -- TODO confirm against PersonalizedBase).
    images = torch.stack(
        [
            torch.from_numpy(dataset[i]["image"]).permute(2, 0, 1)
            for i in range(dataset.num_images)
        ],
        dim=0,
    )

    sim_img, sim_text = evaluator.evaluate(model, images, opt.prompt)

    # The original computed an unused output dir from the never-defined
    # `opt.out_dir`, raising AttributeError before these lines were reached.
    print("Image similarity: ", sim_img)
    print("Text similarity: ", sim_text)


if __name__ == "__main__":
    main()