# xl_fb/inversion_run_base.py
# Commit 3226a63 by zhiweili: "add pre enhance" (file size 7.87 kB).
import torch
from diffusers import (
DDPMScheduler,
StableDiffusionXLImg2ImgPipeline,
)
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import retrieve_timesteps, retrieve_latents
from PIL import Image
from inversion_utils import get_ddpm_inversion_scheduler, create_xts
from config import get_config, get_num_steps_actual
from functools import partial
from compel import Compel, ReturnedEmbeddingsType
from hidiffusion import apply_hidiffusion, remove_hidiffusion
class Object:
    """Empty attribute container; instances accept arbitrary attributes (used for the module-level ``args``)."""
# Module-level setup: hard-coded run arguments, the SDXL img2img pipeline with a
# DDPM scheduler and HiDiffusion applied, the inversion config built from `args`,
# and a Compel prompt encoder spanning both SDXL tokenizers/text encoders.
args = Object()
vars(args).update(
    images_paths=None,
    images_folder=None,
    force_use_cpu=False,
    folder_name='test_measure_time',
    config_from_file='run_configs/noise_shift_guidance_1_5.yaml',
    save_intermediate_results=False,
    batch_size=None,
    skip_p_to_p=True,
    only_p_to_p=False,
    # NOTE(review): fp16 is False although the pipeline below is loaded in
    # float16 — presumably this flag is only consumed by get_config; confirm.
    fp16=False,
    prompts_file='dataset_measure_time/dataset.json',
    images_in_prompts_file=None,
    seed=986,
    time_measure_n=1,
)
assert args.batch_size is None or args.save_intermediate_results is False, (
    "save_intermediate_results is not implemented for batch_size > 1"
)

# encode_image() reads this module-level name when sampling VAE latents.
generator = None
device = "cuda" if torch.cuda.is_available() else "cpu"

BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
# Alternative checkpoint: BASE_MODEL = "stabilityai/sdxl-turbo"

# NOTE(review): fp16 weights are loaded even when `device` falls back to "cpu";
# confirm CPU execution is actually supported in that configuration.
pipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to(device)
pipeline.scheduler = DDPMScheduler.from_pretrained(BASE_MODEL, subfolder="scheduler")
apply_hidiffusion(pipeline)

config = get_config(args)

# Pooled embeddings are requested only from the second text encoder.
compel_proc = Compel(
    tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
    text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
    returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
    requires_pooled=[False, True],
)
def run(
    input_image: Image.Image,
    src_prompt: str,
    tgt_prompt: str,
    generate_size: int,
    seed: int,
    w1: float,
    w2: float,
    num_steps: int,
    start_step: int,
    guidance_scale: float,
):
    """Edit ``input_image`` from ``src_prompt`` towards ``tgt_prompt`` and return the result.

    Uses the module-level ``pipeline`` with a DDPM-inversion scheduler built by
    ``get_ddpm_inversion_scheduler``. The module-level ``config`` is updated
    on every call.

    Args:
        input_image: Source image to encode with the pipeline's VAE.
        src_prompt: Prompt describing the source image.
        tgt_prompt: Prompt describing the desired edit.
        generate_size: Accepted for interface compatibility; not used here.
        seed: Seed for the ``torch.Generator`` used by ``create_xts`` and the
            pipeline call.
        w1: Value replicated into ``config.ws1`` (one entry per actual step).
        w2: Value replicated into ``config.ws2`` (one entry per actual step).
        num_steps: Number of inversion steps (``config.num_steps_inversion``).
        start_step: Requested ``config.step_start`` before adjustment.
        guidance_scale: Classifier-free guidance scale for the pipeline call.

    Returns:
        The third image produced by the pipeline, i.e. the one conditioned on
        ``tgt_prompt`` (the batch prompts are ``[src, src, tgt]``).

    Raises:
        ValueError: If ``num_steps`` is smaller than 1 (it is used as a divisor).
    """
    # Validate before touching the shared config; num_steps is a divisor below.
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    generator = torch.Generator().manual_seed(seed)

    config.num_steps_inversion = num_steps
    config.step_start = start_step
    num_steps_actual = get_num_steps_actual(config)
    num_steps_inversion = config.num_steps_inversion
    # Fraction of the inversion schedule that is skipped before denoising starts.
    denoising_start = (num_steps_inversion - num_steps_actual) / num_steps_inversion
    print(f"-------->num_steps_inversion: {num_steps_inversion} num_steps_actual: {num_steps_actual} denoising_start: {denoising_start}")

    timesteps, num_inference_steps = retrieve_timesteps(
        pipeline.scheduler, num_steps_inversion, device, None
    )
    timesteps, num_inference_steps = pipeline.get_timesteps(
        num_inference_steps=num_inference_steps,
        denoising_start=denoising_start,
        strength=0,
        device=device,
    )
    timesteps = timesteps.type(torch.int64)
    timesteps = [torch.tensor(t) for t in timesteps.tolist()]
    timesteps_len = len(timesteps)
    # Re-align step_start / num_steps_actual with the timesteps actually returned.
    config.step_start = start_step + num_steps_actual - timesteps_len
    num_steps_actual = timesteps_len
    config.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]
    print(f"-------->num_steps_inversion: {num_steps_inversion} num_steps_actual: {num_steps_actual} step_start: {config.step_start}")
    print(f"-------->timesteps len: {len(timesteps)} max_norm_zs len: {len(config.max_norm_zs)}")

    # NOTE(review): the seeded `generator` is not passed to encode_image, so any
    # randomness in VAE encoding is not controlled by `seed` — confirm intent.
    x_0 = encode_image(input_image, pipeline)
    x_ts = create_xts(1, None, 0, generator, pipeline.scheduler, timesteps, x_0, no_add_noise=False)
    x_ts = [xt.to(dtype=torch.float16) for xt in x_ts]
    latents = [x_ts[0]]
    x_ts_c_hat = [None]

    config.ws1 = [w1] * num_steps_actual
    config.ws2 = [w2] * num_steps_actual
    # NOTE(review): this replaces the shared pipeline.scheduler, and the next call
    # of run() passes it to retrieve_timesteps again — confirm that is safe.
    pipeline.scheduler = get_ddpm_inversion_scheduler(
        pipeline.scheduler,
        config.step_function,
        config,
        timesteps,
        config.save_timesteps,
        latents,
        x_ts,
        x_ts_c_hat,
        args.save_intermediate_results,
        pipeline,
        x_0,
        [],  # v1s_images
        [],  # v2s_images
        [],  # deltas_images
        [],  # v1_x0s
        [],  # v2_x0s
        [],  # deltas_x0s
        "res12",
        image_name="im_name",
        time_measure_n=args.time_measure_n,
    )

    # One starting latent per prompt in [src, src, tgt].
    latent = latents[0].expand(3, -1, -1, -1)
    prompt = [src_prompt, src_prompt, tgt_prompt]
    conditioning, pooled = compel_proc(prompt)
    # Pass every option explicitly instead of monkey-patching pipeline.__call__
    # with functools.partial, which mutated the shared pipeline on every call.
    images = pipeline(
        image=latent,
        prompt_embeds=conditioning,
        pooled_prompt_embeds=pooled,
        eta=1,
        num_inference_steps=num_steps_inversion,
        guidance_scale=guidance_scale,
        generator=generator,
        denoising_start=denoising_start,
        strength=0,
    ).images
    return images[2]
def encode_image(image, pipe, generator=None):
    """Encode ``image`` into scaled VAE latents using ``pipe``'s VAE.

    Args:
        image: Any input accepted by ``pipe.image_processor.preprocess``.
        pipe: Pipeline providing ``image_processor``, ``vae`` and ``dtype``.
        generator: Optional ``torch.Generator``, or a list with one generator per
            image in the batch, forwarded to ``retrieve_latents``. When omitted
            (or None), the module-level ``generator`` is used, as before this
            parameter existed.

    Returns:
        The latents multiplied by ``pipe.vae.config.scaling_factor`` and cast to
        ``torch.float16``.
    """
    if generator is None:
        # Keep the previous behaviour of reading the module-level name.
        generator = globals().get("generator")

    image = pipe.image_processor.preprocess(image)
    origin_dtype = pipe.dtype
    image = image.to(device=device, dtype=origin_dtype)

    upcast = pipe.vae.config.force_upcast
    if upcast:
        # The VAE config requests float32: upcast both the input and the VAE.
        image = image.float()
        pipe.vae.to(dtype=torch.float32)
    try:
        if isinstance(generator, list):
            # One generator per image: encode each image on its own so each
            # draw uses its own generator. Cover the whole batch, not just
            # the first image.
            init_latents = torch.cat(
                [
                    retrieve_latents(pipe.vae.encode(image[i : i + 1]), generator=generator[i])
                    for i in range(image.shape[0])
                ],
                dim=0,
            )
        else:
            init_latents = retrieve_latents(pipe.vae.encode(image), generator=generator)
    finally:
        # Restore the VAE dtype even if encoding raises, so the shared pipeline
        # is not left in float32.
        if upcast:
            pipe.vae.to(origin_dtype)

    init_latents = init_latents.to(origin_dtype)
    init_latents = pipe.vae.config.scaling_factor * init_latents
    return init_latents.to(dtype=torch.float16)
def get_timesteps(pipe, num_inference_steps, strength, device, denoising_start=None):
    """Select the trailing part of ``pipe.scheduler.timesteps`` to denoise over.

    Args:
        pipe: Object exposing ``scheduler`` with ``timesteps`` (a 1-D tensor,
            assumed ordered from high to low noise), ``order`` and
            ``config.num_train_timesteps``.
        num_inference_steps: Length of the full inference schedule.
        strength: Fraction of the schedule to run; only used when
            ``denoising_start`` is None.
        device: Unused; kept for signature compatibility.
        denoising_start: Optional fraction of the training schedule to skip.
            When given, ``strength`` is ignored.

    Returns:
        ``(timesteps, n)``. Without ``denoising_start``, ``n`` is
        ``num_inference_steps - t_start``. With it, ``n`` is the number of
        selected timesteps (possibly 0, giving an empty tensor).
    """
    scheduler = pipe.scheduler
    if denoising_start is None:
        # Get the original timestep using init_timestep.
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = scheduler.timesteps[t_start * scheduler.order :]
        return timesteps, num_inference_steps - t_start

    # Strength is irrelevant if we directly request a timestep to start at;
    # that is, strength is determined by the denoising_start instead.
    timesteps = scheduler.timesteps
    discrete_timestep_cutoff = int(
        round(
            scheduler.config.num_train_timesteps
            - (denoising_start * scheduler.config.num_train_timesteps)
        )
    )
    num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
    if scheduler.order == 2 and num_inference_steps % 2 == 0:
        # For a 2nd-order scheduler every timestep except the highest is
        # duplicated. An even count would cut between the 1st- and 2nd-derivative
        # steps, so add one to always end after a complete 2nd-order step.
        num_inference_steps = num_inference_steps + 1
    # Slice from an explicit start index: ``timesteps[-0:]`` would return the
    # whole schedule when no timestep falls below the cutoff. Clamp at 0 for
    # the case where the order-2 bump exceeds the schedule length.
    t_start = max(len(timesteps) - num_inference_steps, 0)
    return timesteps[t_start:], num_inference_steps