File size: 3,811 Bytes
79c5088
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os
import torch
from PIL import Image
import jsonc as json

from model import (
    DirectionalAttentionControl,
    StableDiffusionXLImg2ImgPipeline,
    register_attention_editor_diffusers,
)

from utils.pipeline_utils import *
from utils import get_args, extract_mask
from src import get_ddpm_inversion_scheduler
from visualization import save_results


def run(
    image_path,
    src_prompt,
    tgt_prompt,
    masks,
    pipeline: StableDiffusionXLImg2ImgPipeline,
    args,
):
    """Run one mask-guided image edit via DDPM inversion on an SDXL pipeline.

    Args:
        image_path: Path to the source image; it is resized to 512x512.
        src_prompt: Prompt describing the source image.
        tgt_prompt: Prompt describing the desired edited image.
        masks: Either a ready-made ``torch.Tensor`` mask, or a spec that
            ``extract_mask`` converts into a 512x512 mask.
        pipeline: Loaded SDXL img2img pipeline (its scheduler is replaced
            with a DDPM-inversion scheduler as a side effect).
        args: Parsed CLI namespace (seed, timesteps, w1, alpha, beta, mode,
            dift/movement options, ...).

    Returns:
        ``[source_image, reconstructed_image, edited_image]`` where the last
        two are pipeline outputs for the source and target prompts.
    """
    seed = args.seed
    num_timesteps = args.timesteps
    torch.manual_seed(seed)
    generator = torch.Generator(device=SAMPLING_DEVICE).manual_seed(seed)

    timesteps, config = set_pipeline(pipeline, num_timesteps, generator, args)

    # Encode the source image and build the noised latent trajectory x_t,
    # cast to the latent dtype so downstream ops don't mix precisions.
    x_0_image = Image.open(image_path).convert("RGB").resize((512, 512), RESIZE_TYPE)
    x_0 = encode_image(x_0_image, pipeline, generator)
    x_ts = create_xts(
        config.noise_shift_delta,
        config.noise_timesteps,
        generator,
        pipeline.scheduler,
        timesteps,
        x_0,
    )
    x_ts = [xt.to(dtype=x_0.dtype) for xt in x_ts]
    latents = [x_ts[0]]

    # Accept either a precomputed tensor mask or a serializable mask spec.
    mask = masks if isinstance(masks, torch.Tensor) else extract_mask(masks, 512, 512)

    pipeline.scheduler = get_ddpm_inversion_scheduler(
        pipeline.scheduler,
        config,
        timesteps,
        latents,
        x_ts,
        w1=args.w1,
        dift_timestep=args.dift_timestep,
        movement_intensifier=args.movement_intensifier,
        apply_dift_correction=args.apply_dift_correction,
        mask=mask,
    )

    # Hard-coded editor start point (step 0, layer 44) for SDXL;
    # NOTE(review): beta is inverted here (1 - args.beta) — confirm intended.
    step, layer = 0, 44
    editor = DirectionalAttentionControl(
        step,
        layer,
        total_steps=11,
        model_type="SDXL",
        alpha=args.alpha,
        mode=args.mode,
        beta=1 - args.beta,
        structural_alignment=args.structural_alignment,
        support_new_object=args.support_new_object,
    )
    register_attention_editor_diffusers(pipeline, editor)

    # Batch of 3 sharing one latent: [src recon, src recon, tgt edit].
    latent = latents[0].expand(3, -1, -1, -1)
    prompt = [src_prompt, src_prompt, tgt_prompt]
    pipeline.unet.latent_store.reset()
    # Idiomatic call instead of explicit pipeline.__call__(...).
    image = pipeline(image=latent, prompt=prompt).images
    return [x_0_image, image[0], image[2]]


if __name__ == "__main__":
    args = get_args()

    # Map image filename -> prompt/edit metadata; use a context manager so
    # the prompts file is closed (original leaked the handle).
    with open(args.prompts_file, "r") as f:
        img_paths_to_prompts = json.load(f)
    eval_dataset_folder = args.eval_dataset_folder

    img_paths = [
        f"{eval_dataset_folder}/{img_name}" for img_name in img_paths_to_prompts
    ]
    pipeline = load_pipeline(args.fp16, args.cache_dir)

    os.makedirs(args.output_dir, exist_ok=True)

    for img_path in img_paths:
        args.img_path = img_path
        img_name = os.path.basename(img_path)
        entry = img_paths_to_prompts[img_name]
        prompt = entry["src_prompt"]
        edit_prompts = entry["tgt_prompt"]
        args.alpha = entry.get("alpha", 0.7)
        args.beta = entry.get("beta", 1)
        masks = entry.get("masks", None)
        args.mask = masks
        args.source_prompt = prompt
        args.target_prompt = edit_prompts[0]

        res = run(
            img_path,
            prompt,
            edit_prompts[0],
            masks,
            pipeline=pipeline,
            args=args,
        )

        # Release cached GPU memory between images to reduce OOM risk.
        torch.cuda.empty_cache()
        save_results(
            args=args,
            source_prompt=prompt,
            target_prompt=edit_prompts[0],
            images=res,
        )