import argparse, os, sys
sys.path.append(os.path.join(sys.path[0], '..'))
import torch
from omegaconf import OmegaConf

from ldm.util import instantiate_from_config
from ldm.data.personalized import PersonalizedBase
from evaluation.clip_eval import LDMCLIPEvaluator


def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the LDM model described by `config` and load weights from `ckpt`."""
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    if torch.cuda.is_available():
        model.cuda()
    model.eval()
    return model


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--prompt",
        type=str,
        nargs="?",
        default="a painting of a virus monster playing guitar",
        help="the prompt to render"
    )

    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="/data/pretrained_models/ldm/text2img-large/model.ckpt",
        help="Path to pretrained ldm text2img model")

    parser.add_argument(
        "--embedding_path",
        type=str,
        help="Path to a pre-trained embedding manager checkpoint")

    parser.add_argument(
        "--data_dir",
        type=str,
        help="Path to directory with images used to train the embedding vectors"
    )

    opt = parser.parse_args()

    config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval_with_tokens.yaml")  # TODO: Optionally download from same location as ckpt and change this logic
    model = load_model_from_config(config, opt.ckpt_path)  # TODO: check path
    model.embedding_manager.load(opt.embedding_path)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)

    evaluator = LDMCLIPEvaluator(device)

    prompt = opt.prompt

    # Load the training images and stack them into a single (N, C, H, W) tensor.
    data_loader = PersonalizedBase(opt.data_dir, size=256, flip_p=0.0)
    images = [torch.from_numpy(data_loader[i]["image"]).permute(2, 0, 1) for i in range(data_loader.num_images)]
    images = torch.stack(images, dim=0)

    # CLIP-space similarity of generated samples to (a) the training images and (b) the prompt text.
    sim_img, sim_text = evaluator.evaluate(model, images, opt.prompt)

    print("Image similarity: ", sim_img)
    print("Text similarity: ", sim_text)
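
# Example invocation (a minimal sketch): the script path, prompt, embedding
# checkpoint location, and data directory below are assumptions and must be
# adapted to the actual setup; only the ckpt_path default comes from this script.
#
#   python scripts/evaluate_model.py \
#       --prompt "a photo of *" \
#       --ckpt_path /data/pretrained_models/ldm/text2img-large/model.ckpt \
#       --embedding_path logs/<run_name>/checkpoints/embeddings.pt \
#       --data_dir /path/to/training/images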