Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import torch | |
from PIL import Image | |
import jsonc as json | |
from model import ( | |
DirectionalAttentionControl, | |
StableDiffusionXLImg2ImgPipeline, | |
register_attention_editor_diffusers, | |
) | |
from utils.pipeline_utils import * | |
from utils import get_args, extract_mask | |
from src import get_ddpm_inversion_scheduler | |
from visualization import save_results | |
def run(
    image_path,
    src_prompt,
    tgt_prompt,
    masks,
    pipeline: StableDiffusionXLImg2ImgPipeline,
    args,
):
    """Edit one image from ``src_prompt`` towards ``tgt_prompt``.

    The image at ``image_path`` is loaded as RGB, resized to 512x512, encoded
    with ``encode_image`` and noised with ``create_xts``.  The pipeline's
    scheduler is then replaced by the one built by
    ``get_ddpm_inversion_scheduler``, a ``DirectionalAttentionControl`` editor
    is registered on the pipeline, and the pipeline is called on a batch of
    three prompts: ``[src_prompt, src_prompt, tgt_prompt]``.

    Args:
        image_path: Path of the image to edit (anything ``Image.open`` accepts).
        src_prompt: Prompt used for the first two batch entries.
        tgt_prompt: Prompt used for the third batch entry.
        masks: Either a ``torch.Tensor`` that is used as the mask as-is, or any
            other value, which is handed to ``extract_mask(masks, 512, 512)``.
        pipeline: Pipeline to run.  It is mutated: its scheduler is replaced
            and an attention editor is registered on it.
        args: Options object; this function reads ``seed``, ``timesteps``,
            ``w1``, ``dift_timestep``, ``movement_intensifier``,
            ``apply_dift_correction``, ``alpha``, ``mode``, ``beta``,
            ``structural_alignment`` and ``support_new_object`` and also
            forwards the whole object to ``set_pipeline``.

    Returns:
        A three-element list: the resized source image followed by entries 0
        and 2 of the pipeline's ``.images`` output (presumably the
        source-prompt result and the target-prompt result -- confirm against
        the pipeline implementation).
    """
    side = 512  # edge length used for both the image resize and the mask

    # Seed the global RNG and a dedicated generator with the same value.
    torch.manual_seed(args.seed)
    generator = torch.Generator(device=SAMPLING_DEVICE).manual_seed(args.seed)
    timesteps, config = set_pipeline(pipeline, args.timesteps, generator, args)

    source_rgb = Image.open(image_path).convert("RGB")
    x_0_image = source_rgb.resize((side, side), RESIZE_TYPE)
    x_0 = encode_image(x_0_image, pipeline, generator)

    noised = create_xts(
        config.noise_shift_delta,
        config.noise_timesteps,
        generator,
        pipeline.scheduler,
        timesteps,
        x_0,
    )
    # Bring every noised latent to the dtype of the encoded image.
    x_ts = [x_t.to(dtype=x_0.dtype) for x_t in noised]
    latents = [x_ts[0]]

    # A tensor is taken as a ready-made mask; anything else goes through
    # extract_mask.
    mask = (
        masks
        if isinstance(masks, torch.Tensor)
        else extract_mask(masks, side, side)
    )

    pipeline.scheduler = get_ddpm_inversion_scheduler(
        pipeline.scheduler,
        config,
        timesteps,
        latents,
        x_ts,
        w1=args.w1,
        dift_timestep=args.dift_timestep,
        movement_intensifier=args.movement_intensifier,
        apply_dift_correction=args.apply_dift_correction,
        mask=mask,
    )

    # NOTE(review): the start step/layer (0, 44) and total_steps=11 are
    # hard-coded; presumably they match the SDXL configuration used here --
    # confirm.  The editor receives ``1 - args.beta``, not ``args.beta``.
    first_step, first_layer = 0, 44
    editor = DirectionalAttentionControl(
        first_step,
        first_layer,
        total_steps=11,
        model_type="SDXL",
        alpha=args.alpha,
        mode=args.mode,
        beta=1 - args.beta,
        structural_alignment=args.structural_alignment,
        support_new_object=args.support_new_object,
    )
    register_attention_editor_diffusers(pipeline, editor)

    # One latent expanded to a batch of three, matching the three prompts.
    batch_latent = latents[0].expand(3, -1, -1, -1)
    batch_prompts = [src_prompt, src_prompt, tgt_prompt]
    pipeline.unet.latent_store.reset()
    outputs = pipeline.__call__(image=batch_latent, prompt=batch_prompts)
    image = outputs.images
    return [x_0_image, image[0], image[2]]
if __name__ == "__main__":
    # Script entry point: run the edit for every image listed in the prompts
    # file and hand each result to save_results.
    args = get_args()

    # Use a context manager so the prompts file is closed deterministically
    # instead of being left to the garbage collector.
    with open(args.prompts_file, "r") as prompts_file:
        img_paths_to_prompts = json.load(prompts_file)

    eval_dataset_folder = args.eval_dataset_folder
    img_paths = [
        f"{eval_dataset_folder}/{img_name}" for img_name in img_paths_to_prompts
    ]
    pipeline = load_pipeline(args.fp16, args.cache_dir)
    os.makedirs(args.output_dir, exist_ok=True)

    # NOTE(review): the three names below are not read anywhere in the visible
    # code; they are kept in case code outside this view relies on them.
    sim_scores_total = 0
    images_to_plot = []
    output_dir = args.output_dir

    # Iterate the prompt entries directly rather than recovering the key from
    # the path with split("/")[-1]; that lookup raised KeyError for any key
    # that itself contains a "/" (e.g. "subdir/image.png").
    for img_path, (img_name, entry) in zip(
        img_paths, img_paths_to_prompts.items()
    ):
        prompt = entry["src_prompt"]
        # Only the first element of "tgt_prompt" is used below.
        edit_prompts = entry["tgt_prompt"]
        masks = entry.get("masks", None)

        # Per-image settings are written onto args because run() and
        # save_results() both receive the args object.
        args.img_path = img_path
        args.alpha = entry.get("alpha", 0.7)
        args.beta = entry.get("beta", 1)
        args.mask = masks
        args.source_prompt = prompt
        args.target_prompt = edit_prompts[0]

        res = run(
            img_path,
            prompt,
            edit_prompts[0],
            masks,
            pipeline=pipeline,
            args=args,
        )
        # Release cached GPU memory between images.
        torch.cuda.empty_cache()
        save_results(
            args=args,
            source_prompt=prompt,
            target_prompt=edit_prompts[0],
            images=res,
        )