import argparse
import json
import sys
from pathlib import Path

import k_diffusion
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from tqdm import tqdm

sys.path.append("./")
sys.path.append("./stable_diffusion")

from ldm.modules.attention import CrossAttention
from ldm.util import instantiate_from_config
from metrics.clip_similarity import ClipSimilarity


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    return x[(...,) + (None,) * dims_to_append]


def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative."""
    return (x - denoised) / append_dims(sigma, x.ndim)


def get_ancestral_step(sigma_from, sigma_to):
    """Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step."""
    sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5)
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    return sigma_down, sigma_up


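# Ancestral Euler sampler (k-diffusion style) with a prompt-to-prompt
# attention-sharing switch, used to generate aligned image pairs.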
def sample_euler_ancestral(model, x, sigmas, prompt2prompt_threshold=0.0, **extra_args):
    """Ancestral sampling with Euler method steps."""
    s_in = x.new_ones([x.shape[0]])
    for i in range(len(sigmas) - 1):
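        # Share attention between the two prompts for the first
        # `prompt2prompt_threshold` fraction of the denoising schedule.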
        prompt_to_prompt = prompt2prompt_threshold > i / (len(sigmas) - 2)
        for m in model.modules():
            if isinstance(m, CrossAttention):
                m.prompt_to_prompt = prompt_to_prompt
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
        d = to_d(x, sigmas[i], denoised)
        # Euler step down to sigma_down.
        dt = sigma_down - sigmas[i]
        x = x + d * dt
        if sigmas[i + 1] > 0:
            # Add the same noise to both images so the pair stays aligned.
            x = x + torch.randn_like(x[:1]) * sigma_up
    return x


def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    if vae_ckpt is not None:
        print(f"Loading VAE from {vae_ckpt}")
        vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
        # Swap in the standalone VAE weights for the first-stage model.
        sd = {
            k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v
            for k, v in sd.items()
        }
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)
    return model


class CFGDenoiser(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.inner_model = model

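    # Classifier-free guidance: run the unconditional and conditional branches
    # in one batch, then extrapolate from uncond toward cond by cfg_scale.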
    def forward(self, x, sigma, uncond, cond, cfg_scale):
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        cond_in = torch.cat([uncond, cond])
        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
        return uncond + (cond - uncond) * cfg_scale


def to_pil(image: torch.Tensor) -> Image.Image:
    image = 255.0 * rearrange(image.cpu().numpy(), "c h w -> h w c")
    image = Image.fromarray(image.astype(np.uint8))
    return image


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--out_dir",
        type=str,
        required=True,
        help="Path to output dataset directory.",
    )
    parser.add_argument(
        "--prompts_file",
        type=str,
        required=True,
        help="Path to prompts .jsonl file.",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="stable_diffusion/models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt",
        help="Path to stable diffusion checkpoint.",
    )
    parser.add_argument(
        "--vae-ckpt",
        type=str,
        default="stable_diffusion/models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt",
        help="Path to vae checkpoint.",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=100,
        help="Number of sampling steps.",
    )
    parser.add_argument(
        "--n-samples",
        type=int,
        default=100,
        help="Number of samples to generate per prompt (before CLIP filtering).",
    )
    parser.add_argument(
        "--max-out-samples",
        type=int,
        default=4,
        help="Max number of output samples to save per prompt (after CLIP filtering).",
    )
    parser.add_argument(
        "--n-partitions",
        type=int,
        default=1,
        help="Number of total partitions.",
    )
    parser.add_argument(
        "--partition",
        type=int,
        default=0,
        help="Partition index.",
    )
    parser.add_argument(
        "--min-p2p",
        type=float,
        default=0.1,
        help="Min prompt2prompt threshold (portion of denoising for which to fix self attention maps).",
    )
    parser.add_argument(
        "--max-p2p",
        type=float,
        default=0.9,
        help="Max prompt2prompt threshold (portion of denoising for which to fix self attention maps).",
    )
    parser.add_argument(
        "--min-cfg",
        type=float,
        default=7.5,
        help="Min classifier free guidance scale.",
    )
    parser.add_argument(
        "--max-cfg",
        type=float,
        default=15,
        help="Max classifier free guidance scale.",
    )
    parser.add_argument(
        "--clip-threshold",
        type=float,
        default=0.2,
        help="CLIP threshold for text-image similarity of each image.",
    )
    parser.add_argument(
        "--clip-dir-threshold",
        type=float,
        default=0.2,
        help="Directional CLIP threshold for similarity of change between pairs of text and pairs of images.",
    )
    parser.add_argument(
        "--clip-img-threshold",
        type=float,
        default=0.7,
        help="CLIP threshold for image-image similarity.",
    )
    opt = parser.parse_args()

    global_seed = torch.randint(1 << 32, ()).item()
    print(f"Global seed: {global_seed}")
    seed_everything(global_seed)

    model = load_model_from_config(
        OmegaConf.load("stable_diffusion/configs/stable-diffusion/v1-inference.yaml"),
        ckpt=opt.ckpt,
        vae_ckpt=opt.vae_ckpt,
    )
    model.cuda().eval()
    model_wrap = k_diffusion.external.CompVisDenoiser(model)

    clip_similarity = ClipSimilarity().cuda()

    out_dir = Path(opt.out_dir)
    out_dir.mkdir(exist_ok=True, parents=True)

    with open(opt.prompts_file) as fp:
        prompts = [json.loads(line) for line in fp]

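    # Each job processes only its share of the prompts, so dataset generation
    # can be split across --n-partitions parallel workers.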
    print(f"Partition index {opt.partition} ({opt.partition + 1} / {opt.n_partitions})")
    prompts = np.array_split(list(enumerate(prompts)), opt.n_partitions)[opt.partition]

    with torch.no_grad(), torch.autocast("cuda"), model.ema_scope():
        # Unconditional (empty-prompt) embedding, duplicated for the image pair.
        uncond = model.get_learned_conditioning(2 * [""])
        sigmas = model_wrap.get_sigmas(opt.steps)

        for i, prompt in tqdm(prompts, desc="Prompts"):
            prompt_dir = out_dir.joinpath(f"{i:07d}")
            prompt_dir.mkdir(exist_ok=True)

            with open(prompt_dir.joinpath("prompt.json"), "w") as fp:
                json.dump(prompt, fp)

            # Condition on the input caption and the edited caption together.
            cond = model.get_learned_conditioning([prompt["caption"], prompt["output"]])
            results = {}

            with tqdm(total=opt.n_samples, desc="Samples") as progress_bar:
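                # Keep drawing fresh seeds until n_samples pairs have been
                # generated; each seed yields one (caption, output) image pair.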
                while len(results) < opt.n_samples:
                    seed = torch.randint(1 << 32, ()).item()
                    if seed in results:
                        continue
                    torch.manual_seed(seed)

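                    # One shared initial latent (4 x 64 x 64 for 512 x 512 output),
                    # repeated so both images start from identical noise.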
                    x = torch.randn(1, 4, 512 // 8, 512 // 8, device="cuda") * sigmas[0]
                    x = repeat(x, "1 ... -> n ...", n=2)

                    model_wrap_cfg = CFGDenoiser(model_wrap)
                    # Sample the prompt-to-prompt threshold and CFG scale uniformly per pair.
                    p2p_threshold = opt.min_p2p + torch.rand(()).item() * (opt.max_p2p - opt.min_p2p)
                    cfg_scale = opt.min_cfg + torch.rand(()).item() * (opt.max_cfg - opt.min_cfg)
                    extra_args = {"cond": cond, "uncond": uncond, "cfg_scale": cfg_scale}
                    samples_ddim = sample_euler_ancestral(model_wrap_cfg, x, sigmas, p2p_threshold, **extra_args)
                    x_samples_ddim = model.decode_first_stage(samples_ddim)
                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                    x0 = x_samples_ddim[0]
                    x1 = x_samples_ddim[1]

                    clip_sim_0, clip_sim_1, clip_sim_dir, clip_sim_image = clip_similarity(
                        x0[None], x1[None], [prompt["caption"]], [prompt["output"]]
                    )

                    results[seed] = dict(
                        image_0=to_pil(x0),
                        image_1=to_pil(x1),
                        p2p_threshold=p2p_threshold,
                        cfg_scale=cfg_scale,
                        clip_sim_0=clip_sim_0[0].item(),
                        clip_sim_1=clip_sim_1[0].item(),
                        clip_sim_dir=clip_sim_dir[0].item(),
                        clip_sim_image=clip_sim_image[0].item(),
                    )

                    progress_bar.update()

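            # Keep only pairs that clear the CLIP text-image, image-image, and
            # directional thresholds, ranked by directional similarity.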
            metadata = [
                (result["clip_sim_dir"], seed)
                for seed, result in results.items()
                if result["clip_sim_image"] >= opt.clip_img_threshold
                and result["clip_sim_dir"] >= opt.clip_dir_threshold
                and result["clip_sim_0"] >= opt.clip_threshold
                and result["clip_sim_1"] >= opt.clip_threshold
            ]
            metadata.sort(reverse=True)
            for _, seed in metadata[: opt.max_out_samples]:
                result = results[seed]
                image_0 = result.pop("image_0")
                image_1 = result.pop("image_1")
                image_0.save(prompt_dir.joinpath(f"{seed}_0.jpg"), quality=100)
                image_1.save(prompt_dir.joinpath(f"{seed}_1.jpg"), quality=100)
                with open(prompt_dir.joinpath("metadata.jsonl"), "a") as fp:
                    fp.write(f"{json.dumps(dict(seed=seed, **result))}\n")

    print("Done.")


if __name__ == "__main__":
    main()