# --------------------------------------------------------
# InstructDiffusion
# Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix)
# Modified by Zigang Geng (zigang@mail.ustc.edu.cn)
# --------------------------------------------------------
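
# Example invocation (paths and the edit instruction are illustrative):
#   python edit_cli.py --input example.jpg --edit "make it a watercolor painting"
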
from __future__ import annotations

import math
import os
import random
import sys
from argparse import ArgumentParser

import einops
import k_diffusion as K
import numpy as np
import requests
import torch
import torch.nn as nn
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image, ImageOps
from torch import autocast

sys.path.append("./stable_diffusion")
from stable_diffusion.ldm.util import instantiate_from_config
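

# CFGDenoiser wraps the k-diffusion denoiser to apply InstructDiffusion's
# two-scale classifier-free guidance: each step denoises three batched
# variants of the latent (text+image conditioned, image-only, text-only)
# and blends them with separate text and image guidance scales.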
class CFGDenoiser(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
        cfg_z = einops.repeat(z, "b ... -> (repeat b) ...", repeat=3)
        cfg_sigma = einops.repeat(sigma, "b ... -> (repeat b) ...", repeat=3)
        cfg_cond = {
            "c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], cond["c_crossattn"][0]])],
            "c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])],
        }
        # Batch layout: [text+image cond, image-only cond, text-only cond].
        out_cond, out_img_cond, out_txt_cond = \
            self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3)
        return 0.5 * (out_img_cond + out_txt_cond) + \
            text_cfg_scale * (out_cond - out_img_cond) + \
            image_cfg_scale * (out_cond - out_txt_cond)


def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
    # Note: vae_ckpt is accepted for interface compatibility but is not
    # applied in this script.
    model = instantiate_from_config(config.model)
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "state_dict" in pl_sd:
        pl_sd = pl_sd["state_dict"]
    missing, unexpected = model.load_state_dict(pl_sd, strict=False)
    if verbose:
        print(f"Missing keys:\n{missing}")
        print(f"Unexpected keys:\n{unexpected}")
    return model
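

# CLI entry point: parse arguments, load the model, prepare the input image,
# run guided sampling, and save the edited result.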
def main():
    parser = ArgumentParser()
    parser.add_argument("--resolution", default=512, type=int)
    parser.add_argument("--steps", default=100, type=int)
    parser.add_argument("--config", default="configs/instruct_diffusion.yaml", type=str)
    parser.add_argument("--ckpt", default="checkpoints/v1-5-pruned-emaonly-adaption-task.ckpt", type=str)
    parser.add_argument("--vae-ckpt", default=None, type=str)
    parser.add_argument("--input", required=True, type=str)
    parser.add_argument("--outdir", default="logs", type=str)
    parser.add_argument("--edit", required=True, type=str)
    parser.add_argument("--cfg-text", default=5.0, type=float)
    parser.add_argument("--cfg-image", default=1.25, type=float)
    parser.add_argument("--seed", type=int)
    args = parser.parse_args()
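
    # Load the diffusion model and wrap it for k-diffusion sampling with
    # classifier-free guidance.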
    config = OmegaConf.load(args.config)
    model = load_model_from_config(config, args.ckpt, args.vae_ckpt)
    model.eval().cuda()
    model_wrap = K.external.CompVisDenoiser(model)
    model_wrap_cfg = CFGDenoiser(model_wrap)
    null_token = model.get_learned_conditioning([""])
    seed = random.randint(0, 100000) if args.seed is None else args.seed
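
    # The input may be a local path or an http(s) URL.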
    if args.input.startswith("http"):
        input_image = Image.open(requests.get(args.input, stream=True).raw).convert("RGB")
    else:
        input_image = Image.open(args.input).convert("RGB")

    # Rescale so the image roughly matches args.resolution while both sides
    # are rounded to multiples of 64, keeping the latent sizes compatible
    # with the model.
    width, height = input_image.size
    factor = args.resolution / max(width, height)
    factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
    width_resize = int((width * factor) // 64) * 64
    height_resize = int((height * factor) // 64) * 64
    input_image = ImageOps.fit(input_image, (width_resize, height_resize), method=Image.Resampling.LANCZOS)

    output_dir = args.outdir
    os.makedirs(output_dir, exist_ok=True)
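
    # Encode the conditioning, sample in latent space with an Euler-ancestral
    # sampler, then decode back to pixel space.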
    with torch.no_grad(), autocast("cuda"):
        # Text conditioning from the edit instruction; image conditioning
        # from the VAE latent of the (rescaled) input image.
        cond = {}
        cond["c_crossattn"] = [model.get_learned_conditioning([args.edit])]
        input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1
        input_image = rearrange(input_image, "h w c -> 1 c h w").to(next(model.parameters()).device)
        cond["c_concat"] = [model.encode_first_stage(input_image).mode()]
        uncond = {}
        uncond["c_crossattn"] = [null_token]
        uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]

        sigmas = model_wrap.get_sigmas(args.steps)
        extra_args = {
            "cond": cond,
            "uncond": uncond,
            "text_cfg_scale": args.cfg_text,
            "image_cfg_scale": args.cfg_image,
        }
        torch.manual_seed(seed)
        # Start from pure noise at the largest sigma.
        z = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
        z = K.sampling.sample_euler_ancestral(model_wrap_cfg, z, sigmas, extra_args=extra_args)
        x = model.decode_first_stage(z)
        x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
        x = 255.0 * rearrange(x, "1 c h w -> h w c")
        edited_image = Image.fromarray(x.type(torch.uint8).cpu().numpy())
        # Restore the original resolution before saving.
        edited_image = ImageOps.fit(edited_image, (width, height), method=Image.Resampling.LANCZOS)
        input_name = args.input.split("/")[-1].split(".")[0]
        edited_image.save(os.path.join(output_dir, f"output_{input_name}_seed{seed}.jpg"))


if __name__ == "__main__":
    main()