yuvalalaluf committed
Commit
82ef366
•
1 Parent(s): 991d8d3

initial commit

appearance_transfer_model.py ADDED
@@ -0,0 +1,177 @@
1
+ from typing import List, Optional, Callable
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from config import RunConfig
7
+ from constants import OUT_INDEX, STRUCT_INDEX, STYLE_INDEX
8
+ from models.stable_diffusion import CrossImageAttentionStableDiffusionPipeline
9
+ from utils import attention_utils
10
+ from utils.adain import masked_adain
11
+ from utils.model_utils import get_stable_diffusion_model
12
+ from utils.segmentation import Segmentor
13
+
14
+
15
+ class AppearanceTransferModel:
16
+
17
+ def __init__(self, config: RunConfig, pipe: Optional[CrossImageAttentionStableDiffusionPipeline] = None):
18
+ self.config = config
19
+ self.pipe = get_stable_diffusion_model() if pipe is None else pipe
20
+ self.register_attention_control()
21
+ self.segmentor = Segmentor(prompt=config.prompt, object_nouns=[config.object_noun])
22
+ self.latents_app, self.latents_struct = None, None
23
+ self.zs_app, self.zs_struct = None, None
24
+ self.image_app_mask_32, self.image_app_mask_64 = None, None
25
+ self.image_struct_mask_32, self.image_struct_mask_64 = None, None
26
+ self.enable_edit = False
27
+ self.step = 0
28
+
29
+ def set_latents(self, latents_app: torch.Tensor, latents_struct: torch.Tensor):
30
+ self.latents_app = latents_app
31
+ self.latents_struct = latents_struct
32
+
33
+ def set_noise(self, zs_app: torch.Tensor, zs_struct: torch.Tensor):
34
+ self.zs_app = zs_app
35
+ self.zs_struct = zs_struct
36
+
37
+ def set_masks(self, masks: List[torch.Tensor]):
38
+ self.image_app_mask_32, self.image_struct_mask_32, self.image_app_mask_64, self.image_struct_mask_64 = masks
39
+
40
+ def get_adain_callback(self) -> Callable:
41
+
42
+ def callback(st: int, timestep: int, latents: torch.FloatTensor) -> None:
43
+ self.step = st
44
+ # Compute the masks using prompt mixing self-segmentation and use the masks for AdaIN operation
45
+ if self.step == self.config.adain_range.start:
46
+ masks = self.segmentor.get_object_masks()
47
+ self.set_masks(masks)
48
+ # Apply AdaIN operation using the computed masks
49
+ if self.config.adain_range.start <= self.step < self.config.adain_range.end:
50
+ latents[0] = masked_adain(latents[0], latents[1], self.image_struct_mask_64, self.image_app_mask_64)
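+ # masked_adain aligns the output latent's per-channel mean/std with the appearance latent's statistics inside the object masks.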
51
+
52
+ return callback
53
+
54
+ def register_attention_control(self):
55
+
56
+ model_self = self
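+ # Keep a reference to the outer model so the attention processors can read the current step, config, masks, and edit flag.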
57
+
58
+ class AttentionProcessor:
59
+
60
+ def __init__(self, place_in_unet: str):
61
+ self.place_in_unet = place_in_unet
62
+ if not hasattr(F, "scaled_dot_product_attention"):
63
+ raise ImportError("AttentionProcessor requires PyTorch 2.0; please upgrade torch to use scaled_dot_product_attention.")
64
+
65
+ def __call__(self,
66
+ attn,
67
+ hidden_states: torch.Tensor,
68
+ encoder_hidden_states: Optional[torch.Tensor] = None,
69
+ attention_mask=None,
70
+ temb=None,
71
+ perform_swap: bool = False):
72
+
73
+ residual = hidden_states
74
+
75
+ if attn.spatial_norm is not None:
76
+ hidden_states = attn.spatial_norm(hidden_states, temb)
77
+
78
+ input_ndim = hidden_states.ndim
79
+
80
+ if input_ndim == 4:
81
+ batch_size, channel, height, width = hidden_states.shape
82
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
83
+
84
+ batch_size, sequence_length, _ = (
85
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
86
+ )
87
+
88
+ if attention_mask is not None:
89
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
90
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
91
+
92
+ if attn.group_norm is not None:
93
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
94
+
95
+ query = attn.to_q(hidden_states)
96
+
97
+ is_cross = encoder_hidden_states is not None
98
+ if not is_cross:
99
+ encoder_hidden_states = hidden_states
100
+ elif attn.norm_cross:
101
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
102
+
103
+ key = attn.to_k(encoder_hidden_states)
104
+ value = attn.to_v(encoder_hidden_states)
105
+
106
+ inner_dim = key.shape[-1]
107
+ head_dim = inner_dim // attn.heads
108
+ should_mix = False
109
+
110
+ # Potentially apply our cross image attention operation
111
+ # To do so, we need to be in a self-attention layer in the decoder part of the denoising network
112
+ if perform_swap and not is_cross and "up" in self.place_in_unet and model_self.enable_edit:
113
+ if attention_utils.should_mix_keys_and_values(model_self, hidden_states):
114
+ should_mix = True
115
+ if model_self.step % 5 == 0 and model_self.step < 40:
116
+ # Inject the structure's keys and values
117
+ key[OUT_INDEX] = key[STRUCT_INDEX]
118
+ value[OUT_INDEX] = value[STRUCT_INDEX]
119
+ else:
120
+ # Inject the appearance's keys and values
121
+ key[OUT_INDEX] = key[STYLE_INDEX]
122
+ value[OUT_INDEX] = value[STYLE_INDEX]
123
+
124
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
125
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
126
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
127
+
128
+ # Compute the cross attention and apply our contrasting operation
129
+ hidden_states, attn_weight = attention_utils.compute_scaled_dot_product_attention(
130
+ query, key, value,
131
+ edit_map=perform_swap and model_self.enable_edit and should_mix,
132
+ is_cross=is_cross,
133
+ contrast_strength=model_self.config.contrast_strength,
134
+ )
135
+
136
+ # Update attention map for segmentation
137
+ if model_self.config.use_masked_adain and model_self.step == model_self.config.adain_range.start - 1:
138
+ model_self.segmentor.update_attention(attn_weight, is_cross)
139
+
140
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
141
+ hidden_states = hidden_states.to(query[OUT_INDEX].dtype)
142
+
143
+ # linear proj
144
+ hidden_states = attn.to_out[0](hidden_states)
145
+ # dropout
146
+ hidden_states = attn.to_out[1](hidden_states)
147
+
148
+ if input_ndim == 4:
149
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
150
+
151
+ if attn.residual_connection:
152
+ hidden_states = hidden_states + residual
153
+
154
+ hidden_states = hidden_states / attn.rescale_output_factor
155
+
156
+ return hidden_states
157
+
158
+ def register_recr(net_, count, place_in_unet):
159
+ if net_.__class__.__name__ == 'ResnetBlock2D':
160
+ pass
161
+ if net_.__class__.__name__ == 'Attention':
162
+ net_.set_processor(AttentionProcessor(place_in_unet + f"_{count + 1}"))
163
+ return count + 1
164
+ elif hasattr(net_, 'children'):
165
+ for net__ in net_.children():
166
+ count = register_recr(net__, count, place_in_unet)
167
+ return count
168
+
169
+ cross_att_count = 0
170
+ sub_nets = self.pipe.unet.named_children()
171
+ for net in sub_nets:
172
+ if "down" in net[0]:
173
+ cross_att_count += register_recr(net[1], 0, "down")
174
+ elif "up" in net[0]:
175
+ cross_att_count += register_recr(net[1], 0, "up")
176
+ elif "mid" in net[0]:
177
+ cross_att_count += register_recr(net[1], 0, "mid")
config.py ADDED
@@ -0,0 +1,66 @@
1
+ from dataclasses import dataclass
2
+ from pathlib import Path
3
+ from typing import NamedTuple, Optional
4
+
5
+
6
+ class Range(NamedTuple):
7
+ start: int
8
+ end: int
9
+
10
+
11
+ @dataclass
12
+ class RunConfig:
13
+ # Appearance image path
14
+ app_image_path: Path
15
+ # Struct image path
16
+ struct_image_path: Path
17
+ # Domain name (e.g., buildings, animals)
18
+ domain_name: Optional[str] = None
19
+ # Output path
20
+ output_path: Path = Path('./output')
21
+ # Random seed
22
+ seed: int = 42
23
+ # Input prompt for inversion (will use domain name as default)
24
+ prompt: Optional[str] = None
25
+ # Number of timesteps
26
+ num_timesteps: int = 100
27
+ # Whether to use a binary mask for performing AdaIN
28
+ use_masked_adain: bool = True
29
+ # Timesteps to apply cross-attention on 64x64 layers
30
+ cross_attn_64_range: Range = Range(start=10, end=90)
31
+ # Timesteps to apply cross-attention on 32x32 layers
32
+ cross_attn_32_range: Range = Range(start=10, end=70)
33
+ # Timesteps to apply AdaIn
34
+ adain_range: Range = Range(start=20, end=100)
35
+ # Guidance scale
36
+ guidance_scale: float = 7.5
37
+ # Swap guidance scale
38
+ swap_guidance_scale: float = 3.5
39
+ # Attention contrasting strength
40
+ contrast_strength: float = 1.67
41
+ # Object nouns to use for self-segmentation (will use the domain name as default)
42
+ object_noun: Optional[str] = None
43
+ # Whether to load previously saved inverted latent codes
44
+ load_latents: bool = True
45
+ # Number of steps to skip in the denoising process (used value from original edit-friendly DDPM paper)
46
+ skip_steps: int = 32
47
+
48
+ def __post_init__(self):
49
+ self.output_path = self.output_path / self.domain_name
50
+ self.output_path.mkdir(parents=True, exist_ok=True)
51
+
52
+ # Handle the domain name, prompt, and object nouns used for masking, etc.
53
+ if self.use_masked_adain and self.domain_name is None:
54
+ raise ValueError("Must provide --domain_name and --prompt when using masked AdaIN")
55
+ if not self.use_masked_adain and self.domain_name is None:
56
+ self.domain_name = "object"
57
+ if self.prompt is None:
58
+ self.prompt = f"A photo of a {self.domain_name}"
59
+ if self.object_noun is None:
60
+ self.object_noun = self.domain_name
61
+
62
+ # Define the paths to store the inverted latents to
63
+ self.latents_path = Path(self.output_path) / "latents"
64
+ self.latents_path.mkdir(parents=True, exist_ok=True)
65
+ self.app_latent_save_path = self.latents_path / f"{self.app_image_path.stem}.pt"
66
+ self.struct_latent_save_path = self.latents_path / f"{self.struct_image_path.stem}.pt"
constants.py ADDED
@@ -0,0 +1,3 @@
1
+ OUT_INDEX = 0
2
+ STYLE_INDEX = 1
3
+ STRUCT_INDEX = 2
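+ # Batch index convention used by the attention processors and the AdaIN callback:
+ # the denoised batch holds [generated output, appearance (style) image, structure image] in this order.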
demo.py ADDED
@@ -0,0 +1,96 @@
1
+ import sys
2
+ from pathlib import Path
3
+ from typing import List, Optional
4
+
5
+ import gradio as gr
6
+ from PIL import Image
7
+
8
+ from appearance_transfer_model import AppearanceTransferModel
9
+ from run import run_appearance_transfer
10
+ from utils.latent_utils import load_latents_or_invert_images
11
+ from utils.model_utils import get_stable_diffusion_model
12
+
13
+ sys.path.append(".")
14
+ sys.path.append("..")
15
+
16
+ from config import RunConfig
17
+
18
+ DESCRIPTION = '''
19
+ <h1 style="text-align: center;"> Cross-Image Attention for Zero-Shot Appearance Transfer </h1>
20
+ <p style="text-align: center;">
21
+ This is a demo for our <a href="https://arxiv.org/abs/2311.03335">paper</a>:
22
+ "Cross-Image Attention for Zero-Shot Appearance Transfer".
23
+ <br>
24
+ Given two images depicting a source structure and a target appearance, our method generates an image merging
25
+ the structure of one image with the appearance of the other.
26
+ <br>
27
+ We do so in a zero-shot manner, with no optimization or model training required while supporting appearance
28
+ transfer across images that may differ in size and shape.
29
+ </p>
30
+ '''
31
+
32
+ pipe = get_stable_diffusion_model()
33
+
34
+
35
+ def main_pipeline(app_image_path: str,
36
+ struct_image_path: str,
37
+ domain_name: str,
38
+ seed: int,
39
+ prompt: Optional[str] = None) -> List[Image.Image]:
40
+ if prompt == "":
41
+ prompt = None
42
+ config = RunConfig(
43
+ app_image_path=Path(app_image_path),
44
+ struct_image_path=Path(struct_image_path),
45
+ domain_name=domain_name,
46
+ prompt=prompt,
47
+ seed=seed,
48
+ load_latents=False
49
+ )
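+ # load_latents=False: each demo request inverts both input images from scratch before running the transfer.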
50
+ model = AppearanceTransferModel(config=config, pipe=pipe)
51
+ latents_app, latents_struct, noise_app, noise_struct = load_latents_or_invert_images(model=model, cfg=config)
52
+ model.set_latents(latents_app, latents_struct)
53
+ model.set_noise(noise_app, noise_struct)
54
+ print("Running appearance transfer...")
55
+ images = run_appearance_transfer(model=model, cfg=config)
56
+ print("Done.")
57
+ return [images[0]]
58
+
59
+
60
+ with gr.Blocks(css='style.css') as demo:
61
+ gr.Markdown(DESCRIPTION)
62
+
63
+ gr.HTML('''<a href="https://huggingface.co/spaces/yuvalalaluf/cross-image-attention?duplicate=true"><img src="https://bit.ly/3gLdBN6"
64
+ alt="Duplicate Space"></a>''')
65
+
66
+ with gr.Row():
67
+ with gr.Column():
68
+ app_image_path = gr.Image(label="Upload appearance image", type="filepath")
69
+ struct_image_path = gr.Image(label="Upload structure image", type="filepath")
70
+ domain_name = gr.Text(label="Domain name", max_lines=1,
71
+ info="Specifies the domain the objects come from (e.g., 'animal', 'building', etc.).")
72
+ prompt = gr.Text(label="Prompt to use for inversion.", value='',
73
+ info='If this is kept empty, we will use the domain name to define '
74
+ 'the prompt as "A photo of a <domain_name>".')
75
+ random_seed = gr.Number(value=42, label="Random seed", precision=0)
76
+ run_button = gr.Button('Generate')
77
+
78
+ with gr.Column():
79
+ result = gr.Gallery(label='Result')
80
+ inputs = [app_image_path, struct_image_path, domain_name, random_seed, prompt]
81
+ outputs = [result]
82
+ run_button.click(fn=main_pipeline, inputs=inputs, outputs=outputs)
83
+
84
+ with gr.Row():
85
+ examples = [
86
+ ['inputs/zebra.png', 'inputs/giraffe.png', 'animal', 20, None],
87
+ ['inputs/taj_mahal.jpg', 'inputs/duomo.png', 'building', 42, None],
88
+ ['inputs/red_velvet_cake.jpg', 'inputs/chocolate_cake.jpg', 'cake', 42, 'A photo of cake'],
89
+ ]
90
+ gr.Examples(examples=examples,
91
+ inputs=[app_image_path, struct_image_path, domain_name, random_seed, prompt],
92
+ outputs=[result],
93
+ fn=main_pipeline,
94
+ cache_examples=True)
95
+
96
+ demo.launch(share=False, server_name="127.0.0.1", server_port=8888)
environment/environment.yaml ADDED
@@ -0,0 +1,10 @@
1
+ name: cross_image
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - python=3.8.5
7
+ - pip=20.3
8
+ - cudatoolkit=11.3
9
+ - pip:
10
+ - -r requirements.txt
environment/requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ matplotlib==3.6.3
2
+ matplotlib-inline==0.1.6
3
+ jupyter==1.0.0
4
+ numpy==1.24.1
5
+ pyrallis==0.3.1
6
+ torch==2.0.1
7
+ torchvision==0.15.2
8
+ diffusers==0.19.3
9
+ transformers==4.30.2
10
+ accelerate==0.20.3
11
+ huggingface-hub==0.16.4
12
+ xformers==0.0.21
13
+ tokenizers==0.13.3
14
+ nltk==3.8.1
15
+ Pillow==10.1.0
16
+ scikit_learn==1.3.0
17
+ tqdm==4.64.1
inputs/chocolate_cake.jpg ADDED
inputs/duomo.png ADDED
inputs/giraffe.png ADDED
inputs/red_velvet_cake.jpg ADDED
inputs/taj_mahal.jpg ADDED
inputs/zebra.png ADDED
models/__init__.py ADDED
File without changes
models/stable_diffusion.py ADDED
@@ -0,0 +1,240 @@
1
+ from typing import Any, Callable, Dict, List, Optional, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ from diffusers import StableDiffusionPipeline
6
+ from diffusers.models import AutoencoderKL
7
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
8
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
9
+ from diffusers.schedulers import KarrasDiffusionSchedulers
10
+ from tqdm import tqdm
11
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
12
+
13
+ from config import Range
14
+ from models.unet_2d_condition import FreeUUNet2DConditionModel
15
+
16
+
17
+ class CrossImageAttentionStableDiffusionPipeline(StableDiffusionPipeline):
18
+ """ A modification of the standard StableDiffusionPipeline to incorporate our cross-image attention."""
19
+
20
+ def __init__(self, vae: AutoencoderKL,
21
+ text_encoder: CLIPTextModel,
22
+ tokenizer: CLIPTokenizer,
23
+ unet: FreeUUNet2DConditionModel,
24
+ scheduler: KarrasDiffusionSchedulers,
25
+ safety_checker: StableDiffusionSafetyChecker,
26
+ feature_extractor: CLIPImageProcessor,
27
+ requires_safety_checker: bool = True):
28
+ super().__init__(
29
+ vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker
30
+ )
31
+
32
+ @torch.no_grad()
33
+ def __call__(
34
+ self,
35
+ prompt: Union[str, List[str]] = None,
36
+ height: Optional[int] = None,
37
+ width: Optional[int] = None,
38
+ num_inference_steps: int = 50,
39
+ guidance_scale: float = 7.5,
40
+ negative_prompt: Optional[Union[str, List[str]]] = None,
41
+ num_images_per_prompt: Optional[int] = 1,
42
+ eta: float = 0.0,
43
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
44
+ latents: Optional[torch.FloatTensor] = None,
45
+ prompt_embeds: Optional[torch.FloatTensor] = None,
46
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
47
+ output_type: Optional[str] = "pil",
48
+ return_dict: bool = True,
49
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
50
+ callback_steps: int = 1,
51
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
52
+ guidance_rescale: float = 0.0,
53
+ swap_guidance_scale: float = 1.0,
54
+ cross_image_attention_range: Range = Range(10, 90),
55
+ # DDPM addition
56
+ zs: Optional[List[torch.Tensor]] = None
57
+ ):
58
+
59
+ # 0. Default height and width to unet
60
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
61
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
62
+
63
+ # 1. Check inputs. Raise error if not correct
64
+ self.check_inputs(
65
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
66
+ )
67
+
68
+ # 2. Define call parameters
69
+ if prompt is not None and isinstance(prompt, str):
70
+ batch_size = 1
71
+ elif prompt is not None and isinstance(prompt, list):
72
+ batch_size = len(prompt)
73
+ else:
74
+ batch_size = prompt_embeds.shape[0]
75
+
76
+ device = self._execution_device
77
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
78
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
79
+ # corresponds to doing no classifier free guidance.
80
+ do_classifier_free_guidance = guidance_scale > 1.0
81
+
82
+ # 3. Encode input prompt
83
+ text_encoder_lora_scale = (
84
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
85
+ )
86
+ prompt_embeds = self._encode_prompt(
87
+ prompt,
88
+ device,
89
+ num_images_per_prompt,
90
+ do_classifier_free_guidance,
91
+ negative_prompt,
92
+ prompt_embeds=prompt_embeds,
93
+ negative_prompt_embeds=negative_prompt_embeds,
94
+ lora_scale=text_encoder_lora_scale,
95
+ )
96
+
97
+ # 4. Prepare timesteps
98
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
99
+ timesteps = self.scheduler.timesteps
100
+ t_to_idx = {int(v): k for k, v in enumerate(timesteps[-zs[0].shape[0]:])}
101
+ timesteps = timesteps[-zs[0].shape[0]:]
102
+
103
+ # 5. Prepare latent variables
104
+ num_channels_latents = self.unet.config.in_channels
105
+ latents = self.prepare_latents(
106
+ batch_size * num_images_per_prompt,
107
+ num_channels_latents,
108
+ height,
109
+ width,
110
+ prompt_embeds.dtype,
111
+ device,
112
+ generator,
113
+ latents,
114
+ )
115
+
116
+ # 7. Denoising loop
117
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
118
+
119
+ op = tqdm(timesteps[-zs[0].shape[0]:])
120
+ n_timesteps = len(timesteps[-zs[0].shape[0]:])
121
+
122
+ count = 0
123
+ for t in op:
124
+ i = t_to_idx[int(t)]
125
+
126
+ # expand the latents if we are doing classifier free guidance
127
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
128
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
129
+
130
+ noise_pred_swap = self.unet(
131
+ latent_model_input,
132
+ t,
133
+ encoder_hidden_states=prompt_embeds,
134
+ cross_attention_kwargs={'perform_swap': True},
135
+ return_dict=False,
136
+ )[0]
137
+ noise_pred_no_swap = self.unet(
138
+ latent_model_input,
139
+ t,
140
+ encoder_hidden_states=prompt_embeds,
141
+ cross_attention_kwargs={'perform_swap': False},
142
+ return_dict=False,
143
+ )[0]
144
+
145
+ # perform guidance
146
+ if do_classifier_free_guidance:
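+ # Mix the unconditional prediction of the no-swap pass with the text-conditioned prediction of the swap pass.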
147
+ _, noise_swap_pred_text = noise_pred_swap.chunk(2)
148
+ noise_no_swap_pred_uncond, _ = noise_pred_no_swap.chunk(2)
149
+ noise_pred = noise_no_swap_pred_uncond + guidance_scale * (
150
+ noise_swap_pred_text - noise_no_swap_pred_uncond)
151
+ else:
152
+ is_cross_image_step = cross_image_attention_range.start <= i <= cross_image_attention_range.end
153
+ if swap_guidance_scale > 1.0 and is_cross_image_step:
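+ # Linearly anneal the swap guidance strength from swap_guidance_scale down to max(swap_guidance_scale / 2, 1.0) over the denoising steps.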
154
+ swapping_strengths = np.linspace(swap_guidance_scale,
155
+ max(swap_guidance_scale / 2, 1.0),
156
+ n_timesteps)
157
+ swapping_strength = swapping_strengths[count]
158
+ noise_pred = noise_pred_no_swap + swapping_strength * (noise_pred_swap - noise_pred_no_swap)
159
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_swap, guidance_rescale=guidance_rescale)
160
+ else:
161
+ noise_pred = noise_pred_swap
162
+
163
+ latents = torch.stack([
164
+ self.perform_ddpm_step(t_to_idx, zs[latent_idx], latents[latent_idx], t, noise_pred[latent_idx], eta)
165
+ for latent_idx in range(latents.shape[0])
166
+ ])
167
+
168
+ # call the callback, if provided
169
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
170
+ # progress_bar.update()
171
+ if callback is not None and i % callback_steps == 0:
172
+ callback(i, t, latents)
173
+
174
+ count += 1
175
+
176
+ if not output_type == "latent":
177
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
178
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
179
+ else:
180
+ image = latents
181
+ has_nsfw_concept = None
182
+
183
+ if has_nsfw_concept is None:
184
+ do_denormalize = [True] * image.shape[0]
185
+ else:
186
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
187
+
188
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
189
+
190
+ # Offload last model to CPU
191
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
192
+ self.final_offload_hook.offload()
193
+
194
+ if not return_dict:
195
+ return (image, has_nsfw_concept)
196
+
197
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
198
+
199
+ def perform_ddpm_step(self, t_to_idx, zs, latents, t, noise_pred, eta):
200
+ idx = t_to_idx[int(t)]
201
+ z = zs[idx] if zs is not None else None
202
+ # 1. get previous step value (=t-1)
203
+ prev_timestep = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
204
+ # 2. compute alphas, betas
205
+ alpha_prod_t = self.scheduler.alphas_cumprod[t]
206
+ alpha_prod_t_prev = self.scheduler.alphas_cumprod[
207
+ prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
208
+ beta_prod_t = 1 - alpha_prod_t
209
+ # 3. compute predicted original sample from predicted noise also called
210
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
211
+ pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5)
212
+ # 5. compute variance: "sigma_t(Ξ·)" -> see formula (16)
213
+ # Οƒ_t = sqrt((1 βˆ’ Ξ±_tβˆ’1)/(1 βˆ’ Ξ±_t)) * sqrt(1 βˆ’ Ξ±_t/Ξ±_tβˆ’1)
214
+ # variance = self.scheduler._get_variance(timestep, prev_timestep)
215
+ variance = self.get_variance(t)
216
+ std_dev_t = eta * variance ** (0.5)
217
+ # Take care of asymmetric reverse process (asyrp)
218
+ model_output_direction = noise_pred
219
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
220
+ # pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output_direction
221
+ pred_sample_direction = (1 - alpha_prod_t_prev - eta * variance) ** (0.5) * model_output_direction
222
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
223
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
224
+ # 8. Add noise if eta > 0
225
+ if eta > 0:
226
+ if z is None:
227
+ z = torch.randn(noise_pred.shape, device=self.device)
228
+ sigma_z = eta * variance ** (0.5) * z
229
+ prev_sample = prev_sample + sigma_z
230
+ return prev_sample
231
+
232
+ def get_variance(self, timestep):
233
+ prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
234
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
235
+ alpha_prod_t_prev = self.scheduler.alphas_cumprod[
236
+ prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
237
+ beta_prod_t = 1 - alpha_prod_t
238
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
239
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
240
+ return variance
models/unet_2d_condition.py ADDED
@@ -0,0 +1,345 @@
1
+ from typing import Any, Dict, Optional, Tuple, Union
2
+
3
+ import torch
4
+ import torch.utils.checkpoint
5
+ from diffusers import UNet2DConditionModel
6
+ from diffusers.models.unet_2d_condition import UNet2DConditionOutput
7
+ from diffusers.utils import logging
8
+ from torch.fft import fftn, ifftn, fftshift, ifftshift
9
+
10
+ """
11
+ This is a small extension of the standard UNet2DConditionModel that adds the
12
+ Free-U trick (https://github.com/ChenyangSi/FreeU).
13
+ """
14
+
15
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
16
+
17
+
18
+ def Fourier_filter(x, threshold, scale):
19
+ # FFT
20
+ x_freq = fftn(x, dim=(-2, -1))
21
+ x_freq = fftshift(x_freq, dim=(-2, -1))
22
+
23
+ B, C, H, W = x_freq.shape
24
+ mask = torch.ones((B, C, H, W)).cuda() # create the mask on the GPU (assumes CUDA is available)
25
+
26
+ crow, ccol = H // 2, W // 2
27
+ mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
28
+ x_freq = x_freq * mask
29
+
30
+ # IFFT
31
+ x_freq = ifftshift(x_freq, dim=(-2, -1))
32
+ x_filtered = ifftn(x_freq, dim=(-2, -1)).real
33
+
34
+ return x_filtered
35
+
36
+
37
+ class FreeUUNet2DConditionModel(UNet2DConditionModel):
38
+
39
+ def forward(
40
+ self,
41
+ sample: torch.FloatTensor,
42
+ timestep: Union[torch.Tensor, float, int],
43
+ encoder_hidden_states: torch.Tensor,
44
+ class_labels: Optional[torch.Tensor] = None,
45
+ timestep_cond: Optional[torch.Tensor] = None,
46
+ attention_mask: Optional[torch.Tensor] = None,
47
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
48
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
49
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
50
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
51
+ encoder_attention_mask: Optional[torch.Tensor] = None,
52
+ return_dict: bool = True,
53
+ ) -> Union[UNet2DConditionOutput, Tuple]:
54
+ r"""
55
+ The [`UNet2DConditionModel`] forward method.
56
+
57
+ Args:
58
+ sample (`torch.FloatTensor`):
59
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
60
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
61
+ encoder_hidden_states (`torch.FloatTensor`):
62
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
63
+ encoder_attention_mask (`torch.Tensor`):
64
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
65
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
66
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
67
+ return_dict (`bool`, *optional*, defaults to `True`):
68
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
69
+ tuple.
70
+ cross_attention_kwargs (`dict`, *optional*):
71
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
72
+ added_cond_kwargs: (`dict`, *optional*):
73
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
74
+ are passed along to the UNet blocks.
75
+
76
+ Returns:
77
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
78
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
79
+ a `tuple` is returned where the first element is the sample tensor.
80
+ """
81
+ # By default samples have to be at least a multiple of the overall upsampling factor.
82
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
83
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
84
+ # on the fly if necessary.
85
+ default_overall_up_factor = 2 ** self.num_upsamplers
86
+
87
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
88
+ forward_upsample_size = False
89
+ upsample_size = None
90
+
91
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
92
+ logger.info("Forward upsample size to force interpolation output size.")
93
+ forward_upsample_size = True
94
+
95
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
96
+ # expects mask of shape:
97
+ # [batch, key_tokens]
98
+ # adds singleton query_tokens dimension:
99
+ # [batch, 1, key_tokens]
100
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
101
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
102
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
103
+ if attention_mask is not None:
104
+ # assume that mask is expressed as:
105
+ # (1 = keep, 0 = discard)
106
+ # convert mask into a bias that can be added to attention scores:
107
+ # (keep = +0, discard = -10000.0)
108
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
109
+ attention_mask = attention_mask.unsqueeze(1)
110
+
111
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
112
+ if encoder_attention_mask is not None:
113
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
114
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
115
+
116
+ # 0. center input if necessary
117
+ if self.config.center_input_sample:
118
+ sample = 2 * sample - 1.0
119
+
120
+ # 1. time
121
+ timesteps = timestep
122
+ if not torch.is_tensor(timesteps):
123
+ # This would be a good case for the `match` statement (Python 3.10+)
124
+ is_mps = sample.device.type == "mps"
125
+ if isinstance(timestep, float):
126
+ dtype = torch.float32 if is_mps else torch.float64
127
+ else:
128
+ dtype = torch.int32 if is_mps else torch.int64
129
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
130
+ elif len(timesteps.shape) == 0:
131
+ timesteps = timesteps[None].to(sample.device)
132
+
133
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
134
+ timesteps = timesteps.expand(sample.shape[0])
135
+
136
+ t_emb = self.time_proj(timesteps)
137
+
138
+ # `Timesteps` does not contain any weights and will always return f32 tensors
139
+ # but time_embedding might actually be running in fp16. so we need to cast here.
140
+ # there might be better ways to encapsulate this.
141
+ t_emb = t_emb.to(dtype=sample.dtype)
142
+
143
+ emb = self.time_embedding(t_emb, timestep_cond)
144
+ aug_emb = None
145
+
146
+ if self.class_embedding is not None:
147
+ if class_labels is None:
148
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
149
+
150
+ if self.config.class_embed_type == "timestep":
151
+ class_labels = self.time_proj(class_labels)
152
+
153
+ # `Timesteps` does not contain any weights and will always return f32 tensors
154
+ # there might be better ways to encapsulate this.
155
+ class_labels = class_labels.to(dtype=sample.dtype)
156
+
157
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
158
+
159
+ if self.config.class_embeddings_concat:
160
+ emb = torch.cat([emb, class_emb], dim=-1)
161
+ else:
162
+ emb = emb + class_emb
163
+
164
+ if self.config.addition_embed_type == "text":
165
+ aug_emb = self.add_embedding(encoder_hidden_states)
166
+ elif self.config.addition_embed_type == "text_image":
167
+ # Kandinsky 2.1 - style
168
+ if "image_embeds" not in added_cond_kwargs:
169
+ raise ValueError(
170
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
171
+ )
172
+
173
+ image_embs = added_cond_kwargs.get("image_embeds")
174
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
175
+ aug_emb = self.add_embedding(text_embs, image_embs)
176
+ elif self.config.addition_embed_type == "text_time":
177
+ # SDXL - style
178
+ if "text_embeds" not in added_cond_kwargs:
179
+ raise ValueError(
180
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
181
+ )
182
+ text_embeds = added_cond_kwargs.get("text_embeds")
183
+ if "time_ids" not in added_cond_kwargs:
184
+ raise ValueError(
185
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
186
+ )
187
+ time_ids = added_cond_kwargs.get("time_ids")
188
+ time_embeds = self.add_time_proj(time_ids.flatten())
189
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
190
+
191
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
192
+ add_embeds = add_embeds.to(emb.dtype)
193
+ aug_emb = self.add_embedding(add_embeds)
194
+ elif self.config.addition_embed_type == "image":
195
+ # Kandinsky 2.2 - style
196
+ if "image_embeds" not in added_cond_kwargs:
197
+ raise ValueError(
198
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
199
+ )
200
+ image_embs = added_cond_kwargs.get("image_embeds")
201
+ aug_emb = self.add_embedding(image_embs)
202
+ elif self.config.addition_embed_type == "image_hint":
203
+ # Kandinsky 2.2 - style
204
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
205
+ raise ValueError(
206
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
207
+ )
208
+ image_embs = added_cond_kwargs.get("image_embeds")
209
+ hint = added_cond_kwargs.get("hint")
210
+ aug_emb, hint = self.add_embedding(image_embs, hint)
211
+ sample = torch.cat([sample, hint], dim=1)
212
+
213
+ emb = emb + aug_emb if aug_emb is not None else emb
214
+
215
+ if self.time_embed_act is not None:
216
+ emb = self.time_embed_act(emb)
217
+
218
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
219
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
220
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
221
+ # Kandinsky 2.1 - style
222
+ if "image_embeds" not in added_cond_kwargs:
223
+ raise ValueError(
224
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
225
+ )
226
+
227
+ image_embeds = added_cond_kwargs.get("image_embeds")
228
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
229
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
230
+ # Kandinsky 2.2 - style
231
+ if "image_embeds" not in added_cond_kwargs:
232
+ raise ValueError(
233
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
234
+ )
235
+ image_embeds = added_cond_kwargs.get("image_embeds")
236
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
237
+ # 2. pre-process
238
+ sample = self.conv_in(sample)
239
+
240
+ # 3. down
241
+
242
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
243
+ is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
244
+
245
+ down_block_res_samples = (sample,)
246
+ for downsample_block in self.down_blocks:
247
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
248
+ # For t2i-adapter CrossAttnDownBlock2D
249
+ additional_residuals = {}
250
+ if is_adapter and len(down_block_additional_residuals) > 0:
251
+ additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
252
+
253
+ sample, res_samples = downsample_block(
254
+ hidden_states=sample,
255
+ temb=emb,
256
+ encoder_hidden_states=encoder_hidden_states,
257
+ attention_mask=attention_mask,
258
+ cross_attention_kwargs=cross_attention_kwargs,
259
+ encoder_attention_mask=encoder_attention_mask,
260
+ **additional_residuals,
261
+ )
262
+ else:
263
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
264
+
265
+ if is_adapter and len(down_block_additional_residuals) > 0:
266
+ sample += down_block_additional_residuals.pop(0)
267
+
268
+ down_block_res_samples += res_samples
269
+
270
+ if is_controlnet:
271
+ new_down_block_res_samples = ()
272
+
273
+ for down_block_res_sample, down_block_additional_residual in zip(
274
+ down_block_res_samples, down_block_additional_residuals
275
+ ):
276
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
277
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
278
+
279
+ down_block_res_samples = new_down_block_res_samples
280
+
281
+ # 4. mid
282
+ if self.mid_block is not None:
283
+ sample = self.mid_block(
284
+ sample,
285
+ emb,
286
+ encoder_hidden_states=encoder_hidden_states,
287
+ attention_mask=attention_mask,
288
+ cross_attention_kwargs=cross_attention_kwargs,
289
+ encoder_attention_mask=encoder_attention_mask,
290
+ )
291
+
292
+ if is_controlnet:
293
+ sample = sample + mid_block_additional_residual
294
+
295
+ # 5. up
296
+ for i, upsample_block in enumerate(self.up_blocks):
297
+ is_final_block = i == len(self.up_blocks) - 1
298
+
299
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
300
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
301
+
302
+ # Add the Free-U trick here!
303
+ # Fourier Filter
304
+ if sample.shape[1] == 1280:
305
+ sample[:, :640] *= 1.2 # 1.1 # For SD2.1
306
+ sample = Fourier_filter(sample, threshold=1, scale=0.9)
307
+
308
+ if sample.shape[1] == 640:
309
+ sample[:, :320] *= 1.4 # 1.2 # For SD2.1
310
+ sample = Fourier_filter(sample, threshold=1, scale=0.2)
311
+
312
+ # if we have not reached the final block and need to forward the
313
+ # upsample size, we do it here
314
+ if not is_final_block and forward_upsample_size:
315
+ upsample_size = down_block_res_samples[-1].shape[2:]
316
+
317
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
318
+ sample = upsample_block(
319
+ hidden_states=sample,
320
+ temb=emb,
321
+ res_hidden_states_tuple=res_samples,
322
+ encoder_hidden_states=encoder_hidden_states,
323
+ cross_attention_kwargs=cross_attention_kwargs,
324
+ upsample_size=upsample_size,
325
+ attention_mask=attention_mask,
326
+ encoder_attention_mask=encoder_attention_mask,
327
+ )
328
+ else:
329
+ sample = upsample_block(
330
+ hidden_states=sample,
331
+ temb=emb,
332
+ res_hidden_states_tuple=res_samples,
333
+ upsample_size=upsample_size
334
+ )
335
+
336
+ # 6. post-process
337
+ if self.conv_norm_out:
338
+ sample = self.conv_norm_out(sample)
339
+ sample = self.conv_act(sample)
340
+ sample = self.conv_out(sample)
341
+
342
+ if not return_dict:
343
+ return (sample,)
344
+
345
+ return UNet2DConditionOutput(sample=sample)
utils/__init__.py ADDED
File without changes
utils/adain.py ADDED
@@ -0,0 +1,45 @@
1
+ def masked_adain(content_feat, style_feat, content_mask, style_mask):
2
+ assert (content_feat.size()[:2] == style_feat.size()[:2])
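+ # Normalize the content features with their masked statistics, rescale them with the style statistics, and blend the result back in only where the content mask is active.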
3
+ size = content_feat.size()
4
+ style_mean, style_std = calc_mean_std(style_feat, mask=style_mask)
5
+ content_mean, content_std = calc_mean_std(content_feat, mask=content_mask)
6
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
7
+ style_normalized_feat = normalized_feat * style_std.expand(size) + style_mean.expand(size)
8
+ return content_feat * (1 - content_mask) + style_normalized_feat * content_mask
9
+
10
+
11
+ def calc_mean_std(feat, eps=1e-5, mask=None):
12
+ # eps is a small value added to the variance to avoid divide-by-zero.
13
+ size = feat.size()
14
+ if len(size) == 2:
15
+ return calc_mean_std_2d(feat, eps, mask)
16
+
17
+ assert (len(size) == 3)
18
+ C = size[0]
19
+ if mask is not None:
20
+ feat_var = feat.view(C, -1)[:, mask.view(-1) == 1].var(dim=1) + eps
21
+ feat_std = feat_var.sqrt().view(C, 1, 1)
22
+ feat_mean = feat.view(C, -1)[:, mask.view(-1) == 1].mean(dim=1).view(C, 1, 1)
23
+ else:
24
+ feat_var = feat.view(C, -1).var(dim=1) + eps
25
+ feat_std = feat_var.sqrt().view(C, 1, 1)
26
+ feat_mean = feat.view(C, -1).mean(dim=1).view(C, 1, 1)
27
+
28
+ return feat_mean, feat_std
29
+
30
+
31
+ def calc_mean_std_2d(feat, eps=1e-5, mask=None):
32
+ # eps is a small value added to the variance to avoid divide-by-zero.
33
+ size = feat.size()
34
+ assert (len(size) == 2)
35
+ C = size[0]
36
+ if mask is not None:
37
+ feat_var = feat.view(C, -1)[:, mask.view(-1) == 1].var(dim=1) + eps
38
+ feat_std = feat_var.sqrt().view(C, 1)
39
+ feat_mean = feat.view(C, -1)[:, mask.view(-1) == 1].mean(dim=1).view(C, 1)
40
+ else:
41
+ feat_var = feat.view(C, -1).var(dim=1) + eps
42
+ feat_std = feat_var.sqrt().view(C, 1)
43
+ feat_mean = feat.view(C, -1).mean(dim=1).view(C, 1)
44
+
45
+ return feat_mean, feat_std
utils/attention_utils.py ADDED
@@ -0,0 +1,37 @@
1
+ import math
2
+ import torch
3
+
4
+ from constants import OUT_INDEX
5
+
6
+
7
+ def should_mix_keys_and_values(model, hidden_states: torch.Tensor) -> bool:
8
+ """ Verify whether we should perform the mixing in the current timestep. """
9
+ is_in_32_timestep_range = (
10
+ model.config.cross_attn_32_range.start <= model.step < model.config.cross_attn_32_range.end
11
+ )
12
+ is_in_64_timestep_range = (
13
+ model.config.cross_attn_64_range.start <= model.step < model.config.cross_attn_64_range.end
14
+ )
15
+ is_hidden_states_32_square = (hidden_states.shape[1] == 32 ** 2)
16
+ is_hidden_states_64_square = (hidden_states.shape[1] == 64 ** 2)
17
+ should_mix = (is_in_32_timestep_range and is_hidden_states_32_square) or \
18
+ (is_in_64_timestep_range and is_hidden_states_64_square)
19
+ return should_mix
20
+
21
+
22
+ def compute_scaled_dot_product_attention(Q, K, V, edit_map=False, is_cross=False, contrast_strength=1.0):
23
+ """ Compute the scaled dot-product attention, potentially with our contrasting operation. """
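+ # The raw attention weights are returned together with the output so callers can reuse them (e.g., for the self-segmentation masks).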
24
+ attn_weight = torch.softmax((Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))), dim=-1)
25
+ if edit_map and not is_cross:
26
+ attn_weight[OUT_INDEX] = torch.stack([
27
+ torch.clip(enhance_tensor(attn_weight[OUT_INDEX][head_idx], contrast_factor=contrast_strength),
28
+ min=0.0, max=1.0)
29
+ for head_idx in range(attn_weight.shape[1])
30
+ ])
31
+ return attn_weight @ V, attn_weight
32
+
33
+
34
+ def enhance_tensor(tensor: torch.Tensor, contrast_factor: float = 1.67) -> torch.Tensor:
35
+ """ Compute the attention map contrasting. """
36
+ adjusted_tensor = (tensor - tensor.mean(dim=-1)) * contrast_factor + tensor.mean(dim=-1)
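+ # Values are pushed away from the mean by contrast_factor (> 1 sharpens the map); the caller clips the result back to [0, 1].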
37
+ return adjusted_tensor
utils/ddpm_inversion.py ADDED
@@ -0,0 +1,323 @@
1
+ import abc
2
+
3
+ import torch
4
+ from torch import inference_mode
5
+ from tqdm import tqdm
6
+
7
+ """
8
+ Inversion code taken from:
9
+ 1. The official implementation of Edit-Friendly DDPM Inversion: https://github.com/inbarhub/DDPM_inversion
10
+ 2. The LEDITS demo: https://huggingface.co/spaces/editing-images/ledits/tree/main
11
+ """
12
+
13
+ LOW_RESOURCE = True
14
+
15
+
16
+ def invert(x0, pipe, prompt_src="", num_diffusion_steps=100, cfg_scale_src=3.5, eta=1):
17
+ # inverts a real image according to Algorithm 1 in https://arxiv.org/pdf/2304.06140.pdf,
18
+ # based on the code in https://github.com/inbarhub/DDPM_inversion
19
+ # returns wt, zs, wts:
20
+ # wt - inverted latent
21
+ # wts - intermediate inverted latents
22
+ # zs - noise maps
23
+ pipe.scheduler.set_timesteps(num_diffusion_steps)
24
+ with inference_mode():
25
+ w0 = (pipe.vae.encode(x0).latent_dist.mode() * 0.18215).float()
26
+ wt, zs, wts = inversion_forward_process(pipe, w0, etas=eta, prompt=prompt_src, cfg_scale=cfg_scale_src,
27
+ prog_bar=True, num_inference_steps=num_diffusion_steps)
28
+ return zs, wts
29
+
30
+
31
+ def inversion_forward_process(model, x0,
32
+ etas=None,
33
+ prog_bar=False,
34
+ prompt="",
35
+ cfg_scale=3.5,
36
+ num_inference_steps=50, eps=None
37
+ ):
38
+ if not prompt == "":
39
+ text_embeddings = encode_text(model, prompt)
40
+ uncond_embedding = encode_text(model, "")
41
+ timesteps = model.scheduler.timesteps.to(model.device)
42
+ variance_noise_shape = (
43
+ num_inference_steps,
44
+ model.unet.in_channels,
45
+ model.unet.sample_size,
46
+ model.unet.sample_size)
47
+ if etas is None or (type(etas) in [int, float] and etas == 0):
48
+ eta_is_zero = True
49
+ zs = None
50
+ else:
51
+ eta_is_zero = False
52
+ if type(etas) in [int, float]: etas = [etas] * model.scheduler.num_inference_steps
53
+ xts = sample_xts_from_x0(model, x0, num_inference_steps=num_inference_steps)
54
+ alpha_bar = model.scheduler.alphas_cumprod
55
+ zs = torch.zeros(size=variance_noise_shape, device=model.device)
56
+
57
+ t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
58
+ xt = x0
59
+ op = tqdm(reversed(timesteps)) if prog_bar else reversed(timesteps)
60
+
61
+ for t in op:
62
+ idx = t_to_idx[int(t)]
63
+ # 1. predict noise residual
64
+ if not eta_is_zero:
65
+ xt = xts[idx][None]
66
+
67
+ with torch.no_grad():
68
+ out = model.unet.forward(xt, timestep=t, encoder_hidden_states=uncond_embedding)
69
+ if not prompt == "":
70
+ cond_out = model.unet.forward(xt, timestep=t, encoder_hidden_states=text_embeddings)
71
+
72
+ if not prompt == "":
73
+ ## classifier free guidance
74
+ noise_pred = out.sample + cfg_scale * (cond_out.sample - out.sample)
75
+ else:
76
+ noise_pred = out.sample
77
+
78
+ if eta_is_zero:
79
+ # 2. compute more noisy image and set x_t -> x_t+1
80
+ xt = forward_step(model, noise_pred, t, xt)
81
+
82
+ else:
83
+ xtm1 = xts[idx + 1][None]
84
+ # pred of x0
85
+ pred_original_sample = (xt - (1 - alpha_bar[t]) ** 0.5 * noise_pred) / alpha_bar[t] ** 0.5
86
+
87
+ # direction to xt
88
+ prev_timestep = t - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
89
+ alpha_prod_t_prev = model.scheduler.alphas_cumprod[
90
+ prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
91
+
92
+ variance = get_variance(model, t)
93
+ pred_sample_direction = (1 - alpha_prod_t_prev - etas[idx] * variance) ** (0.5) * noise_pred
94
+
95
+ mu_xt = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
96
+
97
+ z = (xtm1 - mu_xt) / (etas[idx] * variance ** 0.5)
98
+ zs[idx] = z
99
+
100
+ # correction to avoid error accumulation
101
+ xtm1 = mu_xt + (etas[idx] * variance ** 0.5) * z
102
+ xts[idx + 1] = xtm1
103
+
104
+ if zs is not None:
105
+ zs[-1] = torch.zeros_like(zs[-1])
106
+
107
+ return xt, zs, xts
108
+
109
+
110
+ def encode_text(model, prompts):
111
+ text_input = model.tokenizer(
112
+ prompts,
113
+ padding="max_length",
114
+ max_length=model.tokenizer.model_max_length,
115
+ truncation=True,
116
+ return_tensors="pt",
117
+ )
118
+ with torch.no_grad():
119
+ text_encoding = model.text_encoder(text_input.input_ids.to(model.device))[0]
120
+ return text_encoding
121
+
122
+
123
+ def sample_xts_from_x0(model, x0, num_inference_steps=50):
124
+ """
125
+ Samples from P(x_1:T|x_0)
126
+ """
127
+ # torch.manual_seed(43256465436)
128
+ alpha_bar = model.scheduler.alphas_cumprod
129
+ sqrt_one_minus_alpha_bar = (1 - alpha_bar) ** 0.5
130
+ alphas = model.scheduler.alphas
131
+ betas = 1 - alphas
132
+ variance_noise_shape = (
133
+ num_inference_steps,
134
+ model.unet.in_channels,
135
+ model.unet.sample_size,
136
+ model.unet.sample_size)
137
+
138
+ timesteps = model.scheduler.timesteps.to(model.device)
139
+ t_to_idx = {int(v): k for k, v in enumerate(timesteps)}
140
+ xts = torch.zeros(variance_noise_shape).to(x0.device)
141
+ for t in reversed(timesteps):
142
+ idx = t_to_idx[int(t)]
143
+ xts[idx] = x0 * (alpha_bar[t] ** 0.5) + torch.randn_like(x0) * sqrt_one_minus_alpha_bar[t]
144
+ xts = torch.cat([xts, x0], dim=0)
145
+
146
+ return xts
147
+
148
+
149
+ def forward_step(model, model_output, timestep, sample):
150
+ next_timestep = min(model.scheduler.config.num_train_timesteps - 2,
151
+ timestep + model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps)
152
+
153
+ # 2. compute alphas, betas
154
+ alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
155
+
156
+ beta_prod_t = 1 - alpha_prod_t
157
+
158
+ # 3. compute predicted original sample from predicted noise also called
159
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
160
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
161
+ next_sample = model.scheduler.add_noise(pred_original_sample,
162
+ model_output,
163
+ torch.LongTensor([next_timestep]))
164
+ return next_sample
165
+
166
+
167
+ def get_variance(model, timestep):
168
+ prev_timestep = timestep - model.scheduler.config.num_train_timesteps // model.scheduler.num_inference_steps
169
+ alpha_prod_t = model.scheduler.alphas_cumprod[timestep]
170
+ alpha_prod_t_prev = model.scheduler.alphas_cumprod[
171
+ prev_timestep] if prev_timestep >= 0 else model.scheduler.final_alpha_cumprod
172
+ beta_prod_t = 1 - alpha_prod_t
173
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
174
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
175
+ return variance
176
+
177
+
178
+ class AttentionControl(abc.ABC):
179
+
180
+ def step_callback(self, x_t):
181
+ return x_t
182
+
183
+ def between_steps(self):
184
+ return
185
+
186
+ @property
187
+ def num_uncond_att_layers(self):
188
+ return self.num_att_layers if LOW_RESOURCE else 0
189
+
190
+ @abc.abstractmethod
191
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
192
+ raise NotImplementedError
193
+
194
+ def __call__(self, attn, is_cross: bool, place_in_unet: str):
195
+ if self.cur_att_layer >= self.num_uncond_att_layers:
196
+ if LOW_RESOURCE:
197
+ attn = self.forward(attn, is_cross, place_in_unet)
198
+ else:
199
+ h = attn.shape[0]
200
+ attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
201
+ self.cur_att_layer += 1
202
+ if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
203
+ self.cur_att_layer = 0
204
+ self.cur_step += 1
205
+ self.between_steps()
206
+ return attn
207
+
208
+ def reset(self):
209
+ self.cur_step = 0
210
+ self.cur_att_layer = 0
211
+
212
+ def __init__(self):
213
+ self.cur_step = 0
214
+ self.num_att_layers = -1
215
+ self.cur_att_layer = 0
216
+
217
+
218
+ class AttentionStore(AttentionControl):
219
+
220
+ @staticmethod
221
+ def get_empty_store():
222
+ return {"down_cross": [], "mid_cross": [], "up_cross": [],
223
+ "down_self": [], "mid_self": [], "up_self": []}
224
+
225
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
226
+ key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
227
+ if attn.shape[1] <= 32 ** 2: # avoid memory overhead
228
+ self.step_store[key].append(attn)
229
+ return attn
230
+
231
+ def between_steps(self):
232
+ if len(self.attention_store) == 0:
233
+ self.attention_store = self.step_store
234
+ else:
235
+ for key in self.attention_store:
236
+ for i in range(len(self.attention_store[key])):
237
+ self.attention_store[key][i] += self.step_store[key][i]
238
+ self.step_store = self.get_empty_store()
239
+
240
+ def get_average_attention(self):
241
+ average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in
242
+ self.attention_store}
243
+ return average_attention
244
+
245
+ def reset(self):
246
+ super(AttentionStore, self).reset()
247
+ self.step_store = self.get_empty_store()
248
+ self.attention_store = {}
249
+
250
+ def __init__(self):
251
+ super(AttentionStore, self).__init__()
252
+ self.step_store = self.get_empty_store()
253
+ self.attention_store = {}
254
+
255
+
256
+ def register_attention_control(model, controller):
257
+ def ca_forward(self, place_in_unet):
258
+ to_out = self.to_out
259
+ if type(to_out) is torch.nn.modules.container.ModuleList:
260
+ to_out = self.to_out[0]
261
+ else:
262
+ to_out = self.to_out
263
+
264
+ def forward(x, context=None, mask=None):
265
+ batch_size, sequence_length, dim = x.shape
266
+ h = self.heads
267
+ q = self.to_q(x)
268
+ is_cross = context is not None
269
+ context = context if is_cross else x
270
+ k = self.to_k(context)
271
+ v = self.to_v(context)
272
+ q = self.reshape_heads_to_batch_dim(q)
273
+ k = self.reshape_heads_to_batch_dim(k)
274
+ v = self.reshape_heads_to_batch_dim(v)
275
+
276
+ sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
277
+
278
+ if mask is not None:
279
+ mask = mask.reshape(batch_size, -1)
280
+ max_neg_value = -torch.finfo(sim.dtype).max
281
+ mask = mask[:, None, :].repeat(h, 1, 1)
282
+ sim.masked_fill_(~mask, max_neg_value)
283
+
284
+ # attention, what we cannot get enough of
285
+ attn = sim.softmax(dim=-1)
286
+ attn = controller(attn, is_cross, place_in_unet)
287
+ out = torch.einsum("b i j, b j d -> b i d", attn, v)
288
+ out = self.reshape_batch_dim_to_heads(out)
289
+ return to_out(out)
290
+
291
+ return forward
292
+
293
+ class DummyController:
294
+
295
+ def __call__(self, *args):
296
+ return args[0]
297
+
298
+ def __init__(self):
299
+ self.num_att_layers = 0
300
+
301
+ if controller is None:
302
+ controller = DummyController()
303
+
304
+ def register_recr(net_, count, place_in_unet):
305
+ if net_.__class__.__name__ == 'CrossAttention':
306
+ net_.forward = ca_forward(net_, place_in_unet)
307
+ return count + 1
308
+ elif hasattr(net_, 'children'):
309
+ for net__ in net_.children():
310
+ count = register_recr(net__, count, place_in_unet)
311
+ return count
312
+
313
+ cross_att_count = 0
314
+ sub_nets = model.unet.named_children()
315
+ for net in sub_nets:
316
+ if "down" in net[0]:
317
+ cross_att_count += register_recr(net[1], 0, "down")
318
+ elif "up" in net[0]:
319
+ cross_att_count += register_recr(net[1], 0, "up")
320
+ elif "mid" in net[0]:
321
+ cross_att_count += register_recr(net[1], 0, "mid")
322
+
323
+ controller.num_att_layers = cross_att_count
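For orientation, a minimal usage sketch of the Prompt-to-Prompt-style controller above (not part of the commit). It assumes `pipe` is a plain StableDiffusionPipeline built against an older diffusers release whose attention modules are still named `CrossAttention` (the class name `register_attention_control` matches on), and that `LOW_RESOURCE` is defined in the same module as `AttentionControl`.

controller = AttentionStore()
register_attention_control(pipe, controller)   # patches every CrossAttention.forward under pipe.unet
_ = pipe(prompt="a photo of a cat", num_inference_steps=50)   # maps are recorded while sampling
avg = controller.get_average_attention()       # dict keyed by {down,mid,up}_{cross,self}
print({key: len(maps) for key, maps in avg.items()})
controller.reset()                             # clear counters and stores before the next prompt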
utils/image_utils.py ADDED
@@ -0,0 +1,59 @@
1
+ import pathlib
2
+ from typing import Optional, Tuple
3
+
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ from config import RunConfig
8
+
9
+
10
+ def load_images(cfg: RunConfig, save_path: Optional[pathlib.Path] = None) -> Tuple[np.ndarray, np.ndarray]:
11
+ image_style = load_size(cfg.app_image_path)
12
+ image_struct = load_size(cfg.struct_image_path)
13
+ if save_path is not None:
14
+ Image.fromarray(image_style).save(save_path / f"in_style.png")
15
+ Image.fromarray(image_struct).save(save_path / f"in_struct.png")
16
+ return image_style, image_struct
17
+
18
+
19
+ def load_size(image_path: pathlib.Path,
20
+ left: int = 0,
21
+ right: int = 0,
22
+ top: int = 0,
23
+ bottom: int = 0,
24
+ size: int = 512) -> np.ndarray:
25
+ if isinstance(image_path, (str, pathlib.Path)):
26
+ image = np.array(Image.open(image_path).convert('RGB'))
27
+ else:
28
+ image = image_path
29
+
30
+ h, w, c = image.shape
31
+
32
+ left = min(left, w - 1)
33
+ right = min(right, w - left - 1)
34
+ top = min(top, h - 1)
35
+ bottom = min(bottom, h - top - 1)
36
+ image = image[top:h - bottom, left:w - right]
37
+
38
+ h, w, c = image.shape
39
+
40
+ if h < w:
41
+ offset = (w - h) // 2
42
+ image = image[:, offset:offset + h]
43
+ elif w < h:
44
+ offset = (h - w) // 2
45
+ image = image[offset:offset + w]
46
+
47
+ image = np.array(Image.fromarray(image).resize((size, size)))
48
+ return image
49
+
50
+
51
+ def save_generated_masks(model, cfg: RunConfig):
52
+ tensor2im(model.image_app_mask_32).save(cfg.output_path / f"mask_style_32.png")
53
+ tensor2im(model.image_struct_mask_32).save(cfg.output_path / f"mask_struct_32.png")
54
+ tensor2im(model.image_app_mask_64).save(cfg.output_path / f"mask_style_64.png")
55
+ tensor2im(model.image_struct_mask_64).save(cfg.output_path / f"mask_struct_64.png")
56
+
57
+
58
+ def tensor2im(x) -> Image.Image:
59
+ return Image.fromarray(x.cpu().numpy().astype(np.uint8) * 255)
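A short usage sketch for the helpers above (the input path is hypothetical, and the snippet assumes `load_size` and `Image` are imported from this module and PIL): `load_size` center-crops the longer side and resizes to a square, so any RGB input comes back as a `(size, size, 3)` uint8 array.

from pathlib import Path

img = load_size(Path("inputs/zebra.png"), size=512)   # hypothetical input file
print(img.shape, img.dtype)                           # (512, 512, 3) uint8
Image.fromarray(img).save("zebra_512.png")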
utils/latent_utils.py ADDED
@@ -0,0 +1,81 @@
1
+ from pathlib import Path
2
+ from typing import List, Tuple
3
+
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image
7
+
8
+ from appearance_transfer_model import AppearanceTransferModel
+ from models.stable_diffusion import CrossImageAttentionStableDiffusionPipeline
9
+ from config import RunConfig
10
+ from utils import image_utils
11
+ from utils.ddpm_inversion import invert
12
+
13
+
14
+ def load_latents_or_invert_images(model: AppearanceTransferModel, cfg: RunConfig):
15
+ if cfg.load_latents and cfg.app_latent_save_path.exists() and cfg.struct_latent_save_path.exists():
16
+ print("Loading existing latents...")
17
+ latents_app, latents_struct = load_latents(cfg.app_latent_save_path, cfg.struct_latent_save_path)
18
+ noise_app, noise_struct = load_noise(cfg.app_latent_save_path, cfg.struct_latent_save_path)
19
+ print("Done.")
20
+ else:
21
+ print("Inverting images...")
22
+ app_image, struct_image = image_utils.load_images(cfg=cfg, save_path=cfg.output_path)
23
+ model.enable_edit = False # Deactivate the cross-image attention layers
24
+ latents_app, latents_struct, noise_app, noise_struct = invert_images(app_image=app_image,
25
+ struct_image=struct_image,
26
+ sd_model=model.pipe,
27
+ cfg=cfg)
28
+ model.enable_edit = True
29
+ print("Done.")
30
+ return latents_app, latents_struct, noise_app, noise_struct
31
+
32
+
33
+ def load_latents(app_latent_save_path: Path, struct_latent_save_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
34
+ latents_app = torch.load(app_latent_save_path)
35
+ latents_struct = torch.load(struct_latent_save_path)
36
+ if type(latents_struct) == list:
37
+ latents_app = [l.to("cuda") for l in latents_app]
38
+ latents_struct = [l.to("cuda") for l in latents_struct]
39
+ else:
40
+ latents_app = latents_app.to("cuda")
41
+ latents_struct = latents_struct.to("cuda")
42
+ return latents_app, latents_struct
43
+
44
+
45
+ def load_noise(app_latent_save_path: Path, struct_latent_save_path: Path) -> Tuple[torch.Tensor, torch.Tensor]:
46
+ latents_app = torch.load(app_latent_save_path.parent / (app_latent_save_path.stem + "_ddpm_noise.pt"))
47
+ latents_struct = torch.load(struct_latent_save_path.parent / (struct_latent_save_path.stem + "_ddpm_noise.pt"))
48
+ latents_app = latents_app.to("cuda")
49
+ latents_struct = latents_struct.to("cuda")
50
+ return latents_app, latents_struct
51
+
52
+
53
+ def invert_images(sd_model: CrossImageAttentionStableDiffusionPipeline, app_image: np.ndarray, struct_image: np.ndarray, cfg: RunConfig):
54
+ input_app = torch.from_numpy(np.array(app_image)).float() / 127.5 - 1.0
55
+ input_struct = torch.from_numpy(np.array(struct_image)).float() / 127.5 - 1.0
56
+ zs_app, latents_app = invert(x0=input_app.permute(2, 0, 1).unsqueeze(0).to('cuda'),
57
+ pipe=sd_model,
58
+ prompt_src=cfg.prompt,
59
+ num_diffusion_steps=cfg.num_timesteps,
60
+ cfg_scale_src=3.5)
61
+ zs_struct, latents_struct = invert(x0=input_struct.permute(2, 0, 1).unsqueeze(0).to('cuda'),
62
+ pipe=sd_model,
63
+ prompt_src=cfg.prompt,
64
+ num_diffusion_steps=cfg.num_timesteps,
65
+ cfg_scale_src=3.5)
66
+ # Save the inverted latents and noises
67
+ torch.save(latents_app, cfg.latents_path / f"{cfg.app_image_path.stem}.pt")
68
+ torch.save(latents_struct, cfg.latents_path / f"{cfg.struct_image_path.stem}.pt")
69
+ torch.save(zs_app, cfg.latents_path / f"{cfg.app_image_path.stem}_ddpm_noise.pt")
70
+ torch.save(zs_struct, cfg.latents_path / f"{cfg.struct_image_path.stem}_ddpm_noise.pt")
71
+ return latents_app, latents_struct, zs_app, zs_struct
72
+
73
+
74
+ def get_init_latents_and_noises(model: AppearanceTransferModel, cfg: RunConfig) -> Tuple[torch.Tensor, List[torch.Tensor]]:
75
+ # If we stored all the latents along the diffusion process, select the desired one based on the skip_steps
76
+ if model.latents_struct.dim() == 4 and model.latents_app.dim() == 4 and model.latents_app.shape[0] > 1:
77
+ model.latents_struct = model.latents_struct[cfg.skip_steps]
78
+ model.latents_app = model.latents_app[cfg.skip_steps]
79
+ init_latents = torch.stack([model.latents_struct, model.latents_app, model.latents_struct])
80
+ init_zs = [model.zs_struct[cfg.skip_steps:], model.zs_app[cfg.skip_steps:], model.zs_struct[cfg.skip_steps:]]
81
+ return init_latents, init_zs
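Putting the module together, a sketch of the intended calling order (the driver script is not shown in this excerpt, so treat the sequence as an assumption; every function and method used below is defined elsewhere in this commit):

model = AppearanceTransferModel(cfg)   # cfg is a RunConfig
latents_app, latents_struct, zs_app, zs_struct = load_latents_or_invert_images(model, cfg)
model.set_latents(latents_app, latents_struct)
model.set_noise(zs_app, zs_struct)
init_latents, init_zs = get_init_latents_and_noises(model, cfg)
# init_latents is stacked as [struct, appearance, struct]; init_zs holds the matching
# DDPM noise sequences starting from cfg.skip_steps.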
utils/model_utils.py ADDED
@@ -0,0 +1,16 @@
1
+ import torch
2
+ from diffusers import DDIMScheduler
3
+
4
+ from models.stable_diffusion import CrossImageAttentionStableDiffusionPipeline
5
+ from models.unet_2d_condition import FreeUUNet2DConditionModel
6
+
7
+
8
+ def get_stable_diffusion_model() -> CrossImageAttentionStableDiffusionPipeline:
9
+ print("Loading Stable Diffusion model...")
10
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
11
+ pipe = CrossImageAttentionStableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
12
+ safety_checker=None).to(device)
13
+ pipe.unet = FreeUUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet").to(device)
14
+ pipe.scheduler = DDIMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
15
+ print("Done.")
16
+ return pipe
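A quick smoke test for the loader (downloading the SD 1.5 weights is required; the printed class names are simply what the code above constructs):

pipe = get_stable_diffusion_model()
print(type(pipe.unet).__name__)        # FreeUUNet2DConditionModel
print(type(pipe.scheduler).__name__)   # DDIMScheduler
print(pipe.device)                     # cuda if available, otherwise cpu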
utils/segmentation.py ADDED
@@ -0,0 +1,111 @@
1
+ from typing import Tuple, List
2
+
3
+ import nltk
4
+ import numpy as np
5
+ import torch
6
+ from sklearn.cluster import KMeans
7
+
8
+ from constants import STYLE_INDEX, STRUCT_INDEX
9
+
10
+ nltk.download('punkt')
11
+ nltk.download('averaged_perceptron_tagger')
12
+
13
+ """
14
+ Self-segmentation technique taken from Prompt Mixing: https://github.com/orpatashnik/local-prompt-mixing
15
+ """
16
+
17
+ class Segmentor:
18
+
19
+ def __init__(self, prompt: str, object_nouns: List[str], num_segments: int = 5, res: int = 32):
20
+ self.prompt = prompt
21
+ self.num_segments = num_segments
22
+ self.resolution = res
23
+ self.object_nouns = object_nouns
24
+ tokenized_prompt = nltk.word_tokenize(prompt)
25
+ forbidden_words = [word.upper() for word in ["photo", "image", "picture"]]
26
+ self.nouns = [(i, word) for (i, (word, pos)) in enumerate(nltk.pos_tag(tokenized_prompt))
27
+ if pos[:2] == 'NN' and word.upper() not in forbidden_words]
28
+
29
+ def update_attention(self, attn, is_cross):
30
+ res = int(attn.shape[2] ** 0.5)
31
+ if is_cross:
32
+ if res == 16:
33
+ self.cross_attention_32 = attn
34
+ elif res == 32:
35
+ self.cross_attention_64 = attn
36
+ else:
37
+ if res == 32:
38
+ self.self_attention_32 = attn
39
+ elif res == 64:
40
+ self.self_attention_64 = attn
41
+
42
+ def __call__(self, *args, **kwargs):
43
+ clusters = self.cluster()
44
+ cluster2noun = self.cluster2noun(clusters)
45
+ return cluster2noun
46
+
47
+ def cluster(self, res: int = 32):
48
+ np.random.seed(1)
49
+ self_attn = self.self_attention_32 if res == 32 else self.self_attention_64
50
+
51
+ style_attn = self_attn[STYLE_INDEX].mean(dim=0).cpu().numpy()
52
+ style_kmeans = KMeans(n_clusters=self.num_segments, n_init=10).fit(style_attn)
53
+ style_clusters = style_kmeans.labels_.reshape(res, res)
54
+
55
+ struct_attn = self_attn[STRUCT_INDEX].mean(dim=0).cpu().numpy()
56
+ struct_kmeans = KMeans(n_clusters=self.num_segments, n_init=10).fit(struct_attn)
57
+ struct_clusters = struct_kmeans.labels_.reshape(res, res)
58
+
59
+ return style_clusters, struct_clusters
60
+
61
+ def cluster2noun(self, clusters, cross_attn, attn_index):
62
+ result = {}
63
+ res = int(cross_attn.shape[2] ** 0.5)
64
+ nouns_indices = [index for (index, word) in self.nouns]
65
+ cross_attn = cross_attn[attn_index].mean(dim=0).reshape(res, res, -1)
66
+ nouns_maps = cross_attn.cpu().numpy()[:, :, [i + 1 for i in nouns_indices]]
67
+ normalized_nouns_maps = np.zeros_like(nouns_maps).repeat(2, axis=0).repeat(2, axis=1)
68
+ for i in range(nouns_maps.shape[-1]):
69
+ curr_noun_map = nouns_maps[:, :, i].repeat(2, axis=0).repeat(2, axis=1)
70
+ normalized_nouns_maps[:, :, i] = (curr_noun_map - np.abs(curr_noun_map.min())) / curr_noun_map.max()
71
+
72
+ max_score = 0
73
+ all_scores = []
74
+ for c in range(self.num_segments):
75
+ cluster_mask = np.zeros_like(clusters)
76
+ cluster_mask[clusters == c] = 1
77
+ score_maps = [cluster_mask * normalized_nouns_maps[:, :, i] for i in range(len(nouns_indices))]
78
+ scores = [score_map.sum() / cluster_mask.sum() for score_map in score_maps]
79
+ all_scores.append(max(scores))
80
+ max_score = max(max(scores), max_score)
81
+
82
+ all_scores.remove(max_score)
83
+ mean_score = sum(all_scores) / len(all_scores)
84
+
85
+ for c in range(self.num_segments):
86
+ cluster_mask = np.zeros_like(clusters)
87
+ cluster_mask[clusters == c] = 1
88
+ score_maps = [cluster_mask * normalized_nouns_maps[:, :, i] for i in range(len(nouns_indices))]
89
+ scores = [score_map.sum() / cluster_mask.sum() for score_map in score_maps]
90
+ result[c] = self.nouns[np.argmax(np.array(scores))] if max(scores) > 1.4 * mean_score else "BG"
91
+
92
+ return result
93
+
94
+ def create_mask(self, clusters, cross_attention, attn_index):
95
+ cluster2noun = self.cluster2noun(clusters, cross_attention, attn_index)
96
+ mask = clusters.copy()
97
+ obj_segments = [c for c in cluster2noun if cluster2noun[c][1] in self.object_nouns]
98
+ for c in range(self.num_segments):
99
+ mask[clusters == c] = 1 if c in obj_segments else 0
100
+ return torch.from_numpy(mask).to("cuda")
101
+
102
+ def get_object_masks(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
103
+ clusters_style_32, clusters_struct_32 = self.cluster(res=32)
104
+ clusters_style_64, clusters_struct_64 = self.cluster(res=64)
105
+
106
+ mask_style_32 = self.create_mask(clusters_style_32, self.cross_attention_32, STYLE_INDEX)
107
+ mask_struct_32 = self.create_mask(clusters_struct_32, self.cross_attention_32, STRUCT_INDEX)
108
+ mask_style_64 = self.create_mask(clusters_style_64, self.cross_attention_64, STYLE_INDEX)
109
+ mask_struct_64 = self.create_mask(clusters_struct_64, self.cross_attention_64, STRUCT_INDEX)
110
+
111
+ return mask_style_32, mask_struct_32, mask_style_64, mask_struct_64
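Finally, a small sketch of the noun-extraction side of `Segmentor` (the attention maps are normally pushed in via `update_attention` during sampling, so only the prompt parsing is exercised here; the prompt and printed output are illustrative):

seg = Segmentor(prompt="a photo of a zebra on the beach", object_nouns=["zebra"])
print(seg.nouns)   # something like [(4, 'zebra'), (7, 'beach')] -- "photo" is filtered out
# During generation, update_attention() caches the self/cross-attention maps and
# get_object_masks() clusters them into foreground masks at 32x32 and 64x64.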