orpatashnik committed
Commit
c4e6a63
1 Parent(s): 04d6ff1
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: indigo
  colorTo: yellow
  sdk: gradio
  sdk_version: 3.23.0
- app_file: app.py
+ app_file: gradio_app.py
  pinned: false
  license: mit
  ---
gradio_app.py ADDED
@@ -0,0 +1,204 @@
+ from __future__ import annotations
+
+
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+
+ import nltk
+ nltk.download('punkt')
+ nltk.download('averaged_perceptron_tagger')
+
+ from main import LPMConfig, main
+
+ DESCRIPTION = '''# Localizing Object-level Shape Variations with Text-to-Image Diffusion Models
+ This is a demo for our ''Localizing Object-level Shape Variations with Text-to-Image Diffusion Models'' [paper](https://arxiv.org/abs/2303.11306).
+ We introduce a method that generates object-level shape variation for a given image.
+ This demo allows using a real image as well as a generated image. For a real image, a matching prompt is required.
+ '''
+
+ def main_pipeline(
+         prompt: str,
+         object_of_interest: str,
+         proxy_words: str,
+         number_of_variations: int,
+         start_prompt_range: int,
+         end_prompt_range: int,
+         objects_to_preserve: str,
+         background_nouns: str,
+         seed: int,
+         input_image: str):
+     prompt = prompt.replace(object_of_interest, '{word}')
+     print(number_of_variations)
+     print(proxy_words)
+     proxy_words = proxy_words.split(',') if proxy_words != '' else []
+     objects_to_preserve = objects_to_preserve.split(',') if objects_to_preserve != '' else []
+     background_nouns = background_nouns.split(',') if background_nouns != '' else []
+     args = LPMConfig(
+         seed=seed,
+         prompt=prompt,
+         object_of_interest=object_of_interest,
+         proxy_words=proxy_words,
+         number_of_variations=number_of_variations,
+         start_prompt_range=start_prompt_range,
+         end_prompt_range=end_prompt_range,
+         objects_to_preserve=objects_to_preserve,
+         background_nouns=background_nouns,
+         real_image_path="" if input_image is None else input_image
+     )
+
+     result_images, result_proxy_words = main(args)
+     result_images = [im.permute(1, 2, 0).cpu().numpy() for im in result_images]
+     result_images = [(im * 255).astype(np.uint8) for im in result_images]
+     result_images = [Image.fromarray(im) for im in result_images]
+
+     return result_images, ",".join(result_proxy_words)
+
+
+ with gr.Blocks(css='style.css') as demo:
+     gr.Markdown(DESCRIPTION)
+
+     with gr.Row():
+         with gr.Column():
+             input_image = gr.Image(
+                 label="Input image (optional)",
+                 type="filepath"
+             )
+             prompt = gr.Text(
+                 label='Prompt',
+                 max_lines=1,
+                 placeholder='A table below a lamp',
+             )
+             object_of_interest = gr.Text(
+                 label='Object of interest',
+                 max_lines=1,
+                 placeholder='lamp',
+             )
+             proxy_words = gr.Text(
+                 label='Proxy words - words used to obtain variations (a comma-separated list of words, can leave empty)',
+                 max_lines=1,
+                 placeholder=''
+             )
+             number_of_variations = gr.Slider(
+                 label='Number of variations (used only for automatic proxy-words)',
+                 minimum=2,
+                 maximum=30,
+                 value=20,
+                 step=1
+             )
+             start_prompt_range = gr.Slider(
+                 label='Number of steps before starting shape interval',
+                 minimum=0,
+                 maximum=50,
+                 value=7,
+                 step=1
+             )
+             end_prompt_range = gr.Slider(
+                 label='Number of steps before ending shape interval',
+                 minimum=1,
+                 maximum=50,
+                 value=17,
+                 step=1
+             )
+             objects_to_preserve = gr.Text(
+                 label='Words corresponding to objects to preserve (a comma-separated list of words, can leave empty)',
+                 max_lines=1,
+                 placeholder='table',
+             )
+             background_nouns = gr.Text(
+                 label='Words corresponding to objects that should be copied from original image (a comma-separated list of words, can leave empty)',
+                 max_lines=1,
+                 placeholder='',
+             )
+             seed = gr.Slider(
+                 label='Seed',
+                 minimum=1,
+                 maximum=100000,
+                 value=0,
+                 step=1
+             )
+
+             run_button = gr.Button('Generate')
+         with gr.Column():
+             result = gr.Gallery(label='Result').style(grid=4)
+             proxy_words_result = gr.Text(label='Used proxy words')
+
+     examples = [
+         [
+             "hamster eating watermelon on the beach",
+             "watermelon",
+             "",
+             20,
+             6,
+             16,
+             "",
+             "hamster,beach",
+             48,
+             None
+         ],
+         [
+             "A decorated lamp in the livingroom",
+             "lamp",
+             "",
+             20,
+             4,
+             14,
+             "livingroom",
+             "",
+             42,
+             None
+         ],
+         [
+             "a snake in the field eats an apple",
+             "snake",
+             "",
+             20,
+             7,
+             17,
+             "apple",
+             "apple,field",
+             10,
+             None
+         ]
+     ]
+
+     gr.Examples(examples=examples,
+                 inputs=[
+                     prompt,
+                     object_of_interest,
+                     proxy_words,
+                     number_of_variations,
+                     start_prompt_range,
+                     end_prompt_range,
+                     objects_to_preserve,
+                     background_nouns,
+                     seed,
+                     input_image
+                 ],
+                 outputs=[
+                     result,
+                     proxy_words_result
+                 ],
+                 fn=main_pipeline,
+                 cache_examples=False)
+
+
+     inputs = [
+         prompt,
+         object_of_interest,
+         proxy_words,
+         number_of_variations,
+         start_prompt_range,
+         end_prompt_range,
+         objects_to_preserve,
+         background_nouns,
+         seed,
+         input_image
+     ]
+     outputs = [
+         result,
+         proxy_words_result
+     ]
+     run_button.click(fn=main_pipeline, inputs=inputs, outputs=outputs)
+
+ demo.queue(max_size=50).launch(share=False)
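Note (not part of the commit): the demo can also be driven headlessly by calling `main_pipeline` directly, which is a convenient way to smoke-test the Space outside the Gradio UI. The sketch below mirrors the first example row and assumes the repository root (with `main.py`, `style.css`, and `vocab.json`) is the working directory and that the pinned dependencies plus a CUDA GPU are available.

# Hypothetical usage sketch; argument order follows main_pipeline above.
from gradio_app import main_pipeline

images, used_proxy_words = main_pipeline(
    prompt="hamster eating watermelon on the beach",
    object_of_interest="watermelon",
    proxy_words="",                  # empty string -> proxy words are chosen automatically
    number_of_variations=20,
    start_prompt_range=6,            # diffusion step where the shape interval starts
    end_prompt_range=16,             # diffusion step where the shape interval ends
    objects_to_preserve="",
    background_nouns="hamster,beach",
    seed=48,
    input_image=None,                # None -> generate the source image instead of inverting a real one
)
print(used_proxy_words)              # comma-separated proxy words that were actually used
for i, im in enumerate(images):      # PIL images, one per proxy word
    im.save(f"variation_{i}.jpg")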
main.py ADDED
@@ -0,0 +1,150 @@
+ import json
+ import os
+ from dataclasses import dataclass, field
+ from typing import List
+
+ import pyrallis
+ import torch
+ from torch.utils.data import DataLoader
+ from torchvision.utils import save_image
+ from torchvision.transforms import ToTensor
+ from tqdm import tqdm
+
+ from src.prompt_to_prompt_controllers import AttentionStore, AttentionReplace
+ from src.null_text_inversion import invert_image
+ from src.prompt_utils import get_proxy_prompts
+ from src.prompt_mixing import PromptMixing
+ from src.diffusion_model_wrapper import DiffusionModelWrapper, get_stable_diffusion_model, get_stable_diffusion_config, \
+     generate_original_image
+
+
+ def save_args_dict(args, similar_words):
+     exp_path = os.path.join(args.exp_dir, args.prompt.replace(' ', '-'), f"seed={args.seed}_{args.exp_name}")
+     os.makedirs(exp_path, exist_ok=True)
+
+     args_dict = vars(args)
+     args_dict['similar_words'] = similar_words
+     with open(os.path.join(exp_path, "opt.json"), 'w') as fp:
+         json.dump(args_dict, fp, sort_keys=True, indent=4)
+
+     return exp_path
+
+
+ def main(args):
+     ldm_stable = get_stable_diffusion_model(args)
+     ldm_stable_config = get_stable_diffusion_config(args)
+
+     similar_words, prompts, another_prompts = get_proxy_prompts(args, ldm_stable)
+     exp_path = save_args_dict(args, similar_words)
+
+     images = []
+     x_t = None
+     uncond_embeddings = None
+
+     if args.real_image_path != "":
+         x_t, uncond_embeddings = invert_image(args, ldm_stable, ldm_stable_config, prompts, exp_path)
+
+     image, x_t, orig_all_latents, orig_mask, average_attention = generate_original_image(args, ldm_stable, ldm_stable_config, prompts, x_t, uncond_embeddings)
+     save_image(ToTensor()(image[0]), f"{exp_path}/{similar_words[0]}.jpg")
+     save_image(torch.from_numpy(orig_mask).float(), f"{exp_path}/{similar_words[0]}_mask.jpg")
+     images.append(image[0])
+
+     object_of_interest_index = args.prompt.split().index('{word}') + 1
+     pm = PromptMixing(args, object_of_interest_index, average_attention)
+
+     do_other_obj_self_attn_masking = len(args.objects_to_preserve) > 0 and args.end_preserved_obj_self_attn_masking > 0
+     do_self_or_cross_attn_inject = args.cross_attn_inject_steps != 0.0 or args.self_attn_inject_steps != 0.0
+     if do_other_obj_self_attn_masking:
+         print("Do self attn other obj masking")
+     if do_self_or_cross_attn_inject:
+         print(f'Do self attn inject for {args.self_attn_inject_steps} steps')
+         print(f'Do cross attn inject for {args.cross_attn_inject_steps} steps')
+
+     another_prompts_dataloader = DataLoader(another_prompts[1:], batch_size=args.batch_size, shuffle=False)
+
+     for another_prompt_batch in tqdm(another_prompts_dataloader):
+         batch_size = len(another_prompt_batch["word"])
+         batch_prompts = prompts * batch_size
+         batch_another_prompt = another_prompt_batch["prompt"]
+         if do_self_or_cross_attn_inject or do_other_obj_self_attn_masking:
+             batch_prompts.append(prompts[0])
+             batch_another_prompt.insert(0, prompts[0])
+
+         if do_self_or_cross_attn_inject:
+             controller = AttentionReplace(batch_another_prompt, ldm_stable.tokenizer, ldm_stable.device,
+                                           ldm_stable_config["low_resource"], ldm_stable_config["num_diffusion_steps"],
+                                           cross_replace_steps=args.cross_attn_inject_steps,
+                                           self_replace_steps=args.self_attn_inject_steps)
+         else:
+             controller = AttentionStore(ldm_stable_config["low_resource"])
+
+         diffusion_model_wrapper = DiffusionModelWrapper(args, ldm_stable, ldm_stable_config, controller, prompt_mixing=pm)
+         with torch.no_grad():
+             image, x_t, _, mask = diffusion_model_wrapper.forward(batch_prompts, latent=x_t, other_prompt=batch_another_prompt,
+                                                                   post_background=args.background_post_process, orig_all_latents=orig_all_latents,
+                                                                   orig_mask=orig_mask, uncond_embeddings=uncond_embeddings)
+
+         for i in range(batch_size):
+             image_index = i + 1 if do_self_or_cross_attn_inject or do_other_obj_self_attn_masking else i
+             save_image(ToTensor()(image[image_index]), f"{exp_path}/{another_prompt_batch['word'][i]}.jpg")
+             if mask is not None:
+                 save_image(torch.from_numpy(mask).float(), f"{exp_path}/{another_prompt_batch['word'][i]}_mask.jpg")
+             images.append(image[image_index])
+
+     images = [ToTensor()(image) for image in images]
+     save_image(images, f"{exp_path}/grid.jpg", nrow=min(max([i for i in range(2, 8) if len(images) % i == 0]), 8))
+     return images, similar_words
+
+
+ @dataclass
+ class LPMConfig:
+
+     # general config
+     seed: int = 10
+     batch_size: int = 1
+     exp_dir: str = "results"
+     exp_name: str = ""
+     display_images: bool = False
+     gpu_id: int = 0
+
+     # Stable Diffusion config
+     auth_token: str = ""
+     low_resource: bool = True
+     num_diffusion_steps: int = 50
+     guidance_scale: float = 7.5
+     max_num_words: int = 77
+
+     # prompt-mixing
+     prompt: str = "a {word} in the field eats an apple"
+     object_of_interest: str = "snake"  # The object for which we generate variations
+     proxy_words: List[str] = field(default_factory=lambda: [])  # Leave empty for automatic proxy words
+     number_of_variations: int = 20
+     start_prompt_range: int = 7  # Number of steps to begin prompt-mixing
+     end_prompt_range: int = 17  # Number of steps to finish prompt-mixing
+
+     # attention based shape localization
+     objects_to_preserve: List[str] = field(default_factory=lambda: [])  # Objects for which to apply attention-based shape localization
+     remove_obj_from_self_mask: bool = True  # If set to True, removes the object of interest from the self-attention mask
+     obj_pixels_injection_threshold: float = 0.05
+     end_preserved_obj_self_attn_masking: int = 40
+
+     # real image
+     real_image_path: str = ""
+
+     # controllable background preservation
+     background_post_process: bool = True
+     background_nouns: List[str] = field(default_factory=lambda: [])  # Objects to take from the original image in addition to the background
+     num_segments: int = 5  # Number of clusters for the segmentation
+     background_segment_threshold: float = 0.3  # Threshold for the segments labeling
+     background_blend_timestep: int = 35  # Number of steps before background blending
+
+     # other
+     cross_attn_inject_steps: float = 0.0
+     self_attn_inject_steps: float = 0.0
+
+
+ if __name__ == '__main__':
+     args = pyrallis.parse(config_class=LPMConfig)
+
+     print(args)
+     main(args)
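Note (not part of the commit): besides the `pyrallis` CLI entry point in `__main__`, `main.py` doubles as a library, so the pipeline can be invoked programmatically by constructing an `LPMConfig`. A minimal sketch, assuming the pinned environment above and enough GPU memory for Stable Diffusion v1-4:

# Hypothetical usage sketch built on the dataclass defaults above.
from main import LPMConfig, main

args = LPMConfig(
    prompt="a {word} in the field eats an apple",  # '{word}' marks the object of interest
    object_of_interest="snake",
    objects_to_preserve=["apple"],
    background_nouns=["apple", "field"],
    seed=10,
)
images, proxy_words = main(args)  # images are CHW tensors; results are also written under args.exp_dir
print(proxy_words)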
requirements.txt ADDED
@@ -0,0 +1,121 @@
+ accelerate==0.18.0
+ anyio==3.6.2
+ argon2-cffi==21.3.0
+ argon2-cffi-bindings==21.2.0
+ asttokens==2.2.1
+ attrs==22.2.0
+ backcall==0.2.0
+ backports.functools-lru-cache==1.6.4
+ beautifulsoup4==4.11.2
+ bleach==6.0.0
+ brotlipy==0.7.0
+ certifi==2022.12.7
+ cffi==1.15.1
+ chardet==4.0.0
+ charset-normalizer==2.0.4
+ click==8.1.3
+ comm==0.1.2
+ contourpy==1.0.5
+ cryptography==38.0.4
+ cycler==0.11.0
+ debugpy==1.5.1
+ decorator==5.1.1
+ defusedxml==0.7.1
+ diffusers==0.10.2
+ entrypoints==0.4
+ executing==1.2.0
+ fastjsonschema==2.16.2
+ filelock==3.10.4
+ flit_core==3.6.0
+ fonttools==4.25.0
+ huggingface-hub==0.13.3
+ idna==3.4
+ importlib-metadata==6.0.0
+ importlib-resources==5.10.2
+ ipykernel==6.19.2
+ ipython==8.8.0
+ ipython-genutils==0.2.0
+ jedi==0.18.2
+ Jinja2==3.1.2
+ joblib==1.2.0
+ jsonschema==4.17.3
+ jupyter-client==7.3.4
+ jupyter_core==4.12.0
+ jupyter-server==1.23.5
+ jupyterlab-pygments==0.2.2
+ kiwisolver==1.4.4
+ MarkupSafe==2.1.2
+ matplotlib==3.6.2
+ matplotlib-inline==0.1.6
+ mistune==2.0.5
+ mkl-fft==1.3.1
+ mkl-random==1.2.2
+ mkl-service==2.4.0
+ munkres==1.1.4
+ mypy-extensions==1.0.0
+ nbclassic==0.5.1
+ nbclient==0.7.2
+ nbconvert==7.2.9
+ nbformat==5.7.3
+ nest-asyncio==1.5.6
+ nltk==3.8.1
+ notebook==6.5.2
+ notebook_shim==0.2.2
+ numpy==1.23.5
+ opencv-python==4.7.0.72
+ packaging==23.0
+ pandocfilters==1.5.0
+ parso==0.8.3
+ pexpect==4.8.0
+ pickleshare==0.7.5
+ Pillow==9.3.0
+ pip==23.0.1
+ pkgutil_resolve_name==1.3.10
+ ply==3.11
+ prometheus-client==0.16.0
+ prompt-toolkit==3.0.36
+ psutil==5.9.4
+ ptyprocess==0.7.0
+ pure-eval==0.2.2
+ pycparser==2.21
+ Pygments==2.14.0
+ pyOpenSSL==22.0.0
+ pyparsing==3.0.9
+ PyQt5-sip==12.11.0
+ pyrallis==0.3.1
+ pyrsistent==0.19.3
+ PySocks==1.7.1
+ python-dateutil==2.8.2
+ PyYAML==6.0
+ pyzmq==25.0.0
+ regex==2023.3.23
+ requests==2.28.1
+ scikit-learn==1.2.2
+ scipy==1.10.1
+ Send2Trash==1.8.0
+ setuptools==65.6.3
+ sip==6.6.2
+ six==1.16.0
+ sniffio==1.3.0
+ soupsieve==2.3.2.post1
+ stack-data==0.6.2
+ terminado==0.17.1
+ threadpoolctl==3.1.0
+ tinycss2==1.2.1
+ tokenizers==0.13.2
+ toml==0.10.2
+ torch==1.13.1
+ torchaudio==0.13.1
+ torchvision==0.14.1
+ tornado==6.2
+ tqdm==4.65.0
+ traitlets==5.7.1
+ transformers==4.25.1
+ typing_extensions==4.4.0
+ typing-inspect==0.8.0
+ urllib3==1.26.14
+ wcwidth==0.2.6
+ webencodings==0.5.1
+ websocket-client==1.5.1
+ wheel==0.37.1
+ zipp==3.11.0
src/attention_based_segmentation.py ADDED
@@ -0,0 +1,67 @@
+ import nltk
+ from sklearn.cluster import KMeans
+ import numpy as np
+
+ from src.attention_utils import aggregate_attention
+
+
+ class Segmentor:
+
+     def __init__(self, controller, prompts, num_segments, background_segment_threshold, res=32, background_nouns=[]):
+         self.controller = controller
+         self.prompts = prompts
+         self.num_segments = num_segments
+         self.background_segment_threshold = background_segment_threshold
+         self.resolution = res
+         self.background_nouns = background_nouns
+
+         self.self_attention = aggregate_attention(controller, res=32, from_where=("up", "down"), prompts=prompts,
+                                                   is_cross=False, select=len(prompts) - 1)
+         self.cross_attention = aggregate_attention(controller, res=16, from_where=("up", "down"), prompts=prompts,
+                                                    is_cross=True, select=len(prompts) - 1)
+         tokenized_prompt = nltk.word_tokenize(prompts[-1])
+         self.nouns = [(i, word) for (i, (word, pos)) in enumerate(nltk.pos_tag(tokenized_prompt)) if pos[:2] == 'NN']
+
+     def __call__(self, *args, **kwargs):
+         clusters = self.cluster()
+         cluster2noun = self.cluster2noun(clusters)
+         return cluster2noun
+
+     def cluster(self):
+         np.random.seed(1)
+         resolution = self.self_attention.shape[0]
+         attn = self.self_attention.cpu().numpy().reshape(resolution ** 2, resolution ** 2)
+         kmeans = KMeans(n_clusters=self.num_segments, n_init=10).fit(attn)
+         clusters = kmeans.labels_
+         clusters = clusters.reshape(resolution, resolution)
+         return clusters
+
+     def cluster2noun(self, clusters):
+         result = {}
+         nouns_indices = [index for (index, word) in self.nouns]
+         nouns_maps = self.cross_attention.cpu().numpy()[:, :, [i + 1 for i in nouns_indices]]
+         normalized_nouns_maps = np.zeros_like(nouns_maps).repeat(2, axis=0).repeat(2, axis=1)
+         for i in range(nouns_maps.shape[-1]):
+             curr_noun_map = nouns_maps[:, :, i].repeat(2, axis=0).repeat(2, axis=1)
+             normalized_nouns_maps[:, :, i] = (curr_noun_map - np.abs(curr_noun_map.min())) / curr_noun_map.max()
+         for c in range(self.num_segments):
+             cluster_mask = np.zeros_like(clusters)
+             cluster_mask[clusters == c] = 1
+             score_maps = [cluster_mask * normalized_nouns_maps[:, :, i] for i in range(len(nouns_indices))]
+             scores = [score_map.sum() / cluster_mask.sum() for score_map in score_maps]
+             result[c] = self.nouns[np.argmax(np.array(scores))] if max(scores) > self.background_segment_threshold else "BG"
+         return result
+
+     def get_background_mask(self, obj_token_index):
+         clusters = self.cluster()
+         cluster2noun = self.cluster2noun(clusters)
+         mask = clusters.copy()
+         obj_segments = [c for c in cluster2noun if cluster2noun[c][0] == obj_token_index - 1]
+         background_segments = [c for c in cluster2noun if cluster2noun[c] == "BG" or cluster2noun[c][1] in self.background_nouns]
+         for c in range(self.num_segments):
+             if c in background_segments and c not in obj_segments:
+                 mask[clusters == c] = 0
+             else:
+                 mask[clusters == c] = 1
+         return mask
+
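Note (not part of the commit): `Segmentor.cluster` treats each latent patch as a point whose features are its row of the averaged self-attention map and groups the 32x32 patches into `num_segments` clusters; `cluster2noun` then labels each cluster with the noun whose cross-attention mass inside it is highest, or "BG" when the best score falls below `background_segment_threshold`. A self-contained toy sketch of the clustering step, with a random matrix standing in for the real attention:

# Toy sketch of the clustering used by Segmentor.cluster (random data, no diffusion model needed).
import numpy as np
from sklearn.cluster import KMeans

res, num_segments = 32, 5
fake_self_attention = np.random.rand(res * res, res * res)  # stand-in for the aggregated self-attention

kmeans = KMeans(n_clusters=num_segments, n_init=10).fit(fake_self_attention)
clusters = kmeans.labels_.reshape(res, res)                  # one segment id per latent patch
print(np.unique(clusters, return_counts=True))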
src/attention_utils.py ADDED
@@ -0,0 +1,99 @@
+ import torch
+ import numpy as np
+ from typing import Tuple, List
+ from cv2 import putText, getTextSize, FONT_HERSHEY_SIMPLEX
+ import matplotlib.pyplot as plt
+ from PIL import Image
+
+ from src.prompt_to_prompt_controllers import AttentionStore
+
+
+ def aggregate_attention(attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int, prompts):
+     out = []
+     attention_maps = attention_store.get_average_attention()
+     num_pixels = res ** 2
+     for location in from_where:
+         for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
+             if item.shape[1] == num_pixels:
+                 cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
+                 out.append(cross_maps)
+     out = torch.cat(out, dim=0)
+     out = out.sum(0) / out.shape[0]
+     return out.cpu()
+
+
+ def show_cross_attention(attention_store: AttentionStore, res: int, from_where: List[str], prompts, tokenizer, select: int = 0):
+     tokens = tokenizer.encode(prompts[select])
+     decoder = tokenizer.decode
+     attention_maps = aggregate_attention(attention_store, res, from_where, True, select, prompts)
+     images = []
+     for i in range(len(tokens)):
+         image = attention_maps[:, :, i]
+         image = 255 * image / image.max()
+         image = image.unsqueeze(-1).expand(*image.shape, 3)
+         image = image.numpy().astype(np.uint8)
+         image = np.array(Image.fromarray(image).resize((256, 256)))
+         image = text_under_image(image, decoder(int(tokens[i])))
+         images.append(image)
+     view_images(np.stack(images, axis=0))
+
+
+ def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str], prompts,
+                              max_com=10, select: int = 0):
+     attention_maps = aggregate_attention(attention_store, res, from_where, False, select, prompts).numpy().reshape(
+         (res ** 2, res ** 2))
+     u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
+     images = []
+     for i in range(max_com):
+         image = vh[i].reshape(res, res)
+         image = image - image.min()
+         image = 255 * image / image.max()
+         image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
+         image = Image.fromarray(image).resize((256, 256))
+         image = np.array(image)
+         images.append(image)
+     view_images(np.concatenate(images, axis=1))
+
+
+ def view_images(images, num_rows=1, offset_ratio=0.02):
+     if type(images) is list:
+         num_empty = len(images) % num_rows
+     elif images.ndim == 4:
+         num_empty = images.shape[0] % num_rows
+     else:
+         images = [images]
+         num_empty = 0
+
+     empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
+     images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
+     num_items = len(images)
+
+     h, w, c = images[0].shape
+     offset = int(h * offset_ratio)
+     num_cols = num_items // num_rows
+     image_ = np.ones((h * num_rows + offset * (num_rows - 1),
+                       w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
+     for i in range(num_rows):
+         for j in range(num_cols):
+             image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
+                 i * num_cols + j]
+
+     pil_img = Image.fromarray(image_)
+     display(pil_img)
+
+
+ def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
+     h, w, c = image.shape
+     offset = int(h * .2)
+     img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
+     font = FONT_HERSHEY_SIMPLEX
+     img[:h] = image
+     textsize = getTextSize(text, font, 1, 2)[0]
+     text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
+     putText(img, text, (text_x, text_y), font, 1, text_color, 2)
+     return img
+
+
+ def display(image):
+     plt.imshow(image)
+     plt.show()
src/diffusion_model_wrapper.py ADDED
@@ -0,0 +1,252 @@
+ import torch
+ import numpy as np
+ from typing import Optional, List
+
+ from diffusers import DDIMScheduler, StableDiffusionPipeline
+ from tqdm import tqdm
+ from cv2 import dilate
+
+ from src.attention_utils import show_cross_attention
+ from src.attention_based_segmentation import Segmentor
+ from src.prompt_to_prompt_controllers import DummyController, AttentionStore
+
+
+ def get_stable_diffusion_model(args):
+     device = torch.device(f'cuda:{args.gpu_id}') if torch.cuda.is_available() else torch.device('cpu')
+     if args.real_image_path != "":
+         scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+         ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=args.auth_token, scheduler=scheduler).to(device)
+     else:
+         ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=args.auth_token).to(device)
+
+     return ldm_stable
+
+
+ def get_stable_diffusion_config(args):
+     return {
+         "low_resource": args.low_resource,
+         "num_diffusion_steps": args.num_diffusion_steps,
+         "guidance_scale": args.guidance_scale,
+         "max_num_words": args.max_num_words
+     }
+
+
+ def generate_original_image(args, ldm_stable, ldm_stable_config, prompts, latent, uncond_embeddings):
+     g_cpu = torch.Generator(device=ldm_stable.device).manual_seed(args.seed)
+     controller = AttentionStore(ldm_stable_config["low_resource"])
+     diffusion_model_wrapper = DiffusionModelWrapper(args, ldm_stable, ldm_stable_config, controller, generator=g_cpu)
+     image, x_t, orig_all_latents, _ = diffusion_model_wrapper.forward(prompts,
+                                                                       latent=latent,
+                                                                       uncond_embeddings=uncond_embeddings)
+     orig_mask = Segmentor(controller, prompts, args.num_segments, args.background_segment_threshold, background_nouns=args.background_nouns)\
+         .get_background_mask(args.prompt.split(' ').index("{word}") + 1)
+     average_attention = controller.get_average_attention()
+     return image, x_t, orig_all_latents, orig_mask, average_attention
+
+
+ class DiffusionModelWrapper:
+     def __init__(self, args, model, model_config, controller=None, prompt_mixing=None, generator=None):
+         self.args = args
+         self.model = model
+         self.model_config = model_config
+         self.controller = controller
+         if self.controller is None:
+             self.controller = DummyController()
+         self.prompt_mixing = prompt_mixing
+         self.device = model.device
+         self.generator = generator
+
+         self.height = 512
+         self.width = 512
+
+         self.diff_step = 0
+         self.register_attention_control()
+
+     def diffusion_step(self, latents, context, t, other_context=None):
+         if self.model_config["low_resource"]:
+             self.uncond_pred = True
+             noise_pred_uncond = self.model.unet(latents, t, encoder_hidden_states=(context[0], None))["sample"]
+             self.uncond_pred = False
+             noise_prediction_text = self.model.unet(latents, t, encoder_hidden_states=(context[1], other_context))["sample"]
+         else:
+             latents_input = torch.cat([latents] * 2)
+             noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=(context, other_context))["sample"]
+             noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
+         noise_pred = noise_pred_uncond + self.model_config["guidance_scale"] * (noise_prediction_text - noise_pred_uncond)
+         latents = self.model.scheduler.step(noise_pred, t, latents)["prev_sample"]
+         latents = self.controller.step_callback(latents)
+         return latents
+
+     def latent2image(self, latents):
+         latents = 1 / 0.18215 * latents
+         image = self.model.vae.decode(latents)['sample']
+         image = (image / 2 + 0.5).clamp(0, 1)
+         image = image.cpu().permute(0, 2, 3, 1).numpy()
+         image = (image * 255).astype(np.uint8)
+         return image
+
+     def init_latent(self, latent, batch_size):
+         if latent is None:
+             latent = torch.randn(
+                 (1, self.model.unet.in_channels, self.height // 8, self.width // 8),
+                 generator=self.generator, device=self.model.device
+             )
+         latents = latent.expand(batch_size, self.model.unet.in_channels, self.height // 8, self.width // 8).to(self.device)
+         return latent, latents
+
+     def register_attention_control(self):
+         def ca_forward(model_self, place_in_unet):
+             to_out = model_self.to_out
+             if type(to_out) is torch.nn.modules.container.ModuleList:
+                 to_out = model_self.to_out[0]
+             else:
+                 to_out = model_self.to_out
+
+             def forward(x, context=None, mask=None):
+                 batch_size, sequence_length, dim = x.shape
+                 h = model_self.heads
+                 q = model_self.to_q(x)
+                 is_cross = context is not None
+                 context = context if is_cross else (x, None)
+
+                 k = model_self.to_k(context[0])
+                 if is_cross and self.prompt_mixing is not None:
+                     v_context = self.prompt_mixing.get_context_for_v(self.diff_step, context[0], context[1])
+                     v = model_self.to_v(v_context)
+                 else:
+                     v = model_self.to_v(context[0])
+
+                 q = model_self.reshape_heads_to_batch_dim(q)
+                 k = model_self.reshape_heads_to_batch_dim(k)
+                 v = model_self.reshape_heads_to_batch_dim(v)
+
+                 sim = torch.einsum("b i d, b j d -> b i j", q, k) * model_self.scale
+
+                 if mask is not None:
+                     mask = mask.reshape(batch_size, -1)
+                     max_neg_value = -torch.finfo(sim.dtype).max
+                     mask = mask[:, None, :].repeat(h, 1, 1)
+                     sim.masked_fill_(~mask, max_neg_value)
+
+                 # attention, what we cannot get enough of
+                 attn = sim.softmax(dim=-1)
+                 if self.enable_attn_controller_changes:
+                     attn = self.controller(attn, is_cross, place_in_unet)
+
+                 if is_cross and context[1] is not None and self.prompt_mixing is not None:
+                     attn = self.prompt_mixing.get_cross_attn(self, self.diff_step, attn, place_in_unet, batch_size)
+
+                 if not is_cross and (not self.model_config["low_resource"] or not self.uncond_pred) and self.prompt_mixing is not None:
+                     attn = self.prompt_mixing.get_self_attn(self, self.diff_step, attn, place_in_unet, batch_size)
+
+                 out = torch.einsum("b i j, b j d -> b i d", attn, v)
+                 out = model_self.reshape_batch_dim_to_heads(out)
+                 return to_out(out)
+
+             return forward
+
+         def register_recr(net_, count, place_in_unet):
+             if net_.__class__.__name__ == 'CrossAttention':
+                 net_.forward = ca_forward(net_, place_in_unet)
+                 return count + 1
+             elif hasattr(net_, 'children'):
+                 for net__ in net_.children():
+                     count = register_recr(net__, count, place_in_unet)
+             return count
+
+         cross_att_count = 0
+         sub_nets = self.model.unet.named_children()
+         for net in sub_nets:
+             if "down" in net[0]:
+                 cross_att_count += register_recr(net[1], 0, "down")
+             elif "up" in net[0]:
+                 cross_att_count += register_recr(net[1], 0, "up")
+             elif "mid" in net[0]:
+                 cross_att_count += register_recr(net[1], 0, "mid")
+         self.controller.num_att_layers = cross_att_count
+
+     def get_text_embedding(self, prompt: List[str], max_length=None, truncation=True):
+         text_input = self.model.tokenizer(
+             prompt,
+             padding="max_length",
+             max_length=self.model.tokenizer.model_max_length if max_length is None else max_length,
+             truncation=truncation,
+             return_tensors="pt",
+         )
+         text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.device))[0]
+         max_length = text_input.input_ids.shape[-1]
+         return text_embeddings, max_length
+
+     @torch.no_grad()
+     def forward(self, prompt: List[str], latent: Optional[torch.FloatTensor] = None,
+                 other_prompt: List[str] = None, post_background=False, orig_all_latents=None, orig_mask=None,
+                 uncond_embeddings=None, start_time=51, return_type='image'):
+         self.enable_attn_controller_changes = True
+         batch_size = len(prompt)
+
+         text_embeddings, max_length = self.get_text_embedding(prompt)
+         if uncond_embeddings is None:
+             uncond_embeddings_, _ = self.get_text_embedding([""] * batch_size, max_length=max_length, truncation=False)
+         else:
+             uncond_embeddings_ = None
+
+         other_context = None
+         if other_prompt is not None:
+             other_text_embeddings, _ = self.get_text_embedding(other_prompt)
+             other_context = other_text_embeddings
+
+         latent, latents = self.init_latent(latent, batch_size)
+
+         # set timesteps
+         self.model.scheduler.set_timesteps(self.model_config["num_diffusion_steps"])
+         all_latents = []
+
+         object_mask = None
+         self.diff_step = 0
+         for i, t in enumerate(tqdm(self.model.scheduler.timesteps[-start_time:])):
+             if uncond_embeddings_ is None:
+                 context = [uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings]
+             else:
+                 context = [uncond_embeddings_, text_embeddings]
+             if not self.model_config["low_resource"]:
+                 context = torch.cat(context)
+
+             self.down_cross_index = 0
+             self.mid_cross_index = 0
+             self.up_cross_index = 0
+             latents = self.diffusion_step(latents, context, t, other_context)
+
+             if post_background and self.diff_step == self.args.background_blend_timestep:
+                 object_mask = Segmentor(self.controller,
+                                         prompt,
+                                         self.args.num_segments,
+                                         self.args.background_segment_threshold,
+                                         background_nouns=self.args.background_nouns)\
+                     .get_background_mask(self.args.prompt.split(' ').index("{word}") + 1)
+                 self.enable_attn_controller_changes = False
+                 mask = object_mask.astype(np.bool8) + orig_mask.astype(np.bool8)
+                 mask = torch.from_numpy(mask).float().cuda()
+                 shape = (1, 1, mask.shape[0], mask.shape[1])
+                 mask = torch.nn.Upsample(size=(64, 64), mode='nearest')(mask.view(shape))
+                 mask_eroded = dilate(mask.cpu().numpy()[0, 0], np.ones((3, 3), np.uint8), iterations=1)
+                 mask = torch.from_numpy(mask_eroded).float().cuda().view(1, 1, 64, 64)
+                 latents = mask * latents + (1 - mask) * orig_all_latents[self.diff_step]
+
+             all_latents.append(latents)
+             self.diff_step += 1
+
+         if return_type == 'image':
+             image = self.latent2image(latents)
+         else:
+             image = latents
+
+         return image, latent, all_latents, object_mask
+
+     def show_last_cross_attention(self, res: int, from_where: List[str], prompts, select: int = 0):
+         show_cross_attention(self.controller, res, from_where, prompts, tokenizer=self.model.tokenizer, select=select)
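Note (not part of the commit): `diffusion_step` above evaluates the UNet on the unconditional and the text-conditional branch and mixes the two predictions with classifier-free guidance before the scheduler update. The mixing itself is the standard formula, sketched here with dummy tensors:

# Sketch of the classifier-free guidance combination used in diffusion_step.
import torch

guidance_scale = 7.5                              # LPMConfig.guidance_scale default
noise_pred_uncond = torch.randn(1, 4, 64, 64)     # UNet output for the empty prompt
noise_pred_text = torch.randn(1, 4, 64, 64)       # UNet output for the conditioning prompt

noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)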
src/null_text_inversion.py ADDED
@@ -0,0 +1,201 @@
+ from typing import Union
+
+ from torchvision.transforms import ToTensor
+ from torchvision.utils import save_image
+ from tqdm import tqdm
+ import torch
+ from torch.optim.adam import Adam
+ import torch.nn.functional as nnf
+ import numpy as np
+ from PIL import Image
+
+
+ def load_512(image_path, left=0, right=0, top=0, bottom=0):
+     if type(image_path) is str:
+         image = np.array(Image.open(image_path))[:, :, :3]
+     else:
+         image = image_path
+     h, w, c = image.shape
+     left = min(left, w - 1)
+     right = min(right, w - left - 1)
+     top = min(top, h - left - 1)
+     bottom = min(bottom, h - top - 1)
+     image = image[top:h - bottom, left:w - right]
+     h, w, c = image.shape
+     if h < w:
+         offset = (w - h) // 2
+         image = image[:, offset:offset + h]
+     elif w < h:
+         offset = (h - w) // 2
+         image = image[offset:offset + w]
+     image = np.array(Image.fromarray(image).resize((512, 512)))
+     return image
+
+
+ def invert_image(args, ldm_stable, ldm_stable_config, prompts, exp_path):
+     print("Start null text inversion")
+     null_inversion = NullInversion(ldm_stable, ldm_stable_config)
+     (image_gt, image_enc), x_t, uncond_embeddings = null_inversion.invert(args.real_image_path, prompts[0], offsets=(0, 0, 0, 0), verbose=True)
+     save_image(ToTensor()(image_gt), f"{exp_path}/real_image.jpg")
+     save_image(ToTensor()(image_enc), f"{exp_path}/image_enc.jpg")
+     print("End null text inversion")
+     return x_t, uncond_embeddings
+
+
+ class NullInversion:
+
+     def __init__(self, model, model_config):
+         self.model = model
+         self.model_config = model_config
+         self.tokenizer = self.model.tokenizer
+         self.model.scheduler.set_timesteps(self.model_config["num_diffusion_steps"])
+         self.prompt = None
+         self.context = None
+
+     def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
+         prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
+         alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+         alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
+         beta_prod_t = 1 - alpha_prod_t
+         pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
+         pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
+         prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
+         return prev_sample
+
+     def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
+         timestep, next_timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep
+         alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
+         alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
+         beta_prod_t = 1 - alpha_prod_t
+         next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
+         next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
+         next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
+         return next_sample
+
+     def get_noise_pred_single(self, latents, t, context):
+         noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"]
+         return noise_pred
+
+     def get_noise_pred(self, latents, t, is_forward=True, context=None):
+         latents_input = torch.cat([latents] * 2)
+         if context is None:
+             context = self.context
+         guidance_scale = 1 if is_forward else self.model_config["guidance_scale"]
+         noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
+         noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
+         noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
+         if is_forward:
+             latents = self.next_step(noise_pred, t, latents)
+         else:
+             latents = self.prev_step(noise_pred, t, latents)
+         return latents
+
+     @torch.no_grad()
+     def latent2image(self, latents, return_type='np'):
+         latents = 1 / 0.18215 * latents.detach()
+         image = self.model.vae.decode(latents)['sample']
+         if return_type == 'np':
+             image = (image / 2 + 0.5).clamp(0, 1)
+             image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
+             image = (image * 255).astype(np.uint8)
+         return image
+
+     @torch.no_grad()
+     def image2latent(self, image):
+         with torch.no_grad():
+             if isinstance(image, Image.Image):
+                 image = np.array(image)
+             if type(image) is torch.Tensor and image.dim() == 4:
+                 latents = image
+             else:
+                 image = torch.from_numpy(image).float() / 127.5 - 1
+                 image = image.permute(2, 0, 1).unsqueeze(0).to(self.model.device)
+                 latents = self.model.vae.encode(image)['latent_dist'].mean
+                 latents = latents * 0.18215
+         return latents
+
+     @torch.no_grad()
+     def init_prompt(self, prompt: str):
+         uncond_input = self.model.tokenizer(
+             [""], padding="max_length", max_length=self.model.tokenizer.model_max_length,
+             return_tensors="pt"
+         )
+         uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
+         text_input = self.model.tokenizer(
+             [prompt],
+             padding="max_length",
+             max_length=self.model.tokenizer.model_max_length,
+             truncation=True,
+             return_tensors="pt",
+         )
+         text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0]
+         self.context = torch.cat([uncond_embeddings, text_embeddings])
+         self.prompt = prompt
+
+     @torch.no_grad()
+     def ddim_loop(self, latent):
+         uncond_embeddings, cond_embeddings = self.context.chunk(2)
+         all_latent = [latent]
+         latent = latent.clone().detach()
+         for i in tqdm(range(self.model_config["num_diffusion_steps"])):
+             t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1]
+             noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings)
+             latent = self.next_step(noise_pred, t, latent)
+             all_latent.append(latent)
+         return all_latent
+
+     @property
+     def scheduler(self):
+         return self.model.scheduler
+
+     @torch.no_grad()
+     def ddim_inversion(self, image):
+         latent = self.image2latent(image)
+         image_rec = self.latent2image(latent)
+         ddim_latents = self.ddim_loop(latent)
+         return image_rec, ddim_latents
+
+     def null_optimization(self, latents, num_inner_steps, epsilon):
+         uncond_embeddings, cond_embeddings = self.context.chunk(2)
+         uncond_embeddings_list = []
+         latent_cur = latents[-1]
+         with tqdm(total=num_inner_steps * (self.model_config["num_diffusion_steps"])) as bar:
+             for i in range(self.model_config["num_diffusion_steps"]):
+                 uncond_embeddings = uncond_embeddings.clone().detach()
+                 uncond_embeddings.requires_grad = True
+                 optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
+                 latent_prev = latents[len(latents) - i - 2]
+                 t = self.model.scheduler.timesteps[i]
+                 with torch.no_grad():
+                     noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings)
+                 for j in range(num_inner_steps):
+                     noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
+                     noise_pred = noise_pred_uncond + self.model_config["guidance_scale"] * (noise_pred_cond - noise_pred_uncond)
+                     latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
+                     loss = nnf.mse_loss(latents_prev_rec, latent_prev)
+                     optimizer.zero_grad()
+                     loss.backward()
+                     optimizer.step()
+                     loss_item = loss.item()
+                     bar.update()
+                     if loss_item < epsilon + i * 2e-5:
+                         break
+                 bar.update(num_inner_steps - j - 1)
+                 uncond_embeddings_list.append(uncond_embeddings[:1].detach())
+                 with torch.no_grad():
+                     context = torch.cat([uncond_embeddings, cond_embeddings])
+                     latent_cur = self.get_noise_pred(latent_cur, t, False, context)
+         return uncond_embeddings_list
+
+     def invert(self, image_path: str, prompt: str, offsets=(0, 0, 0, 0), num_inner_steps=10, early_stop_epsilon=1e-5, verbose=False):
+         self.init_prompt(prompt)
+         image_gt = load_512(image_path, *offsets)
+         if verbose:
+             print("DDIM inversion...")
+         image_rec, ddim_latents = self.ddim_inversion(image_gt)
+         if verbose:
+             print("Null-text optimization...")
+         uncond_embeddings = self.null_optimization(ddim_latents, num_inner_steps, early_stop_epsilon)
+         return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings
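Note (not part of the commit): `NullInversion.next_step` implements the deterministic DDIM inversion update: predict x_0 from the current latent and the noise estimate, then re-noise it with the next step's cumulative alpha. The same arithmetic, with scalar stand-ins for the scheduler's `alphas_cumprod`:

# Sketch of the DDIM inversion update in NullInversion.next_step (illustrative alpha values).
import torch

alpha_t, alpha_next = 0.98, 0.97          # scheduler.alphas_cumprod at t and at the next (noisier) step
x_t = torch.randn(1, 4, 64, 64)           # current latent
eps = torch.randn(1, 4, 64, 64)           # UNet noise prediction at t

x0_pred = (x_t - (1 - alpha_t) ** 0.5 * eps) / alpha_t ** 0.5
x_next = alpha_next ** 0.5 * x0_pred + (1 - alpha_next) ** 0.5 * eps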
src/prompt_mixing.py ADDED
@@ -0,0 +1,86 @@
+ import torch
+ from scipy.signal import medfilt2d
+
+ class PromptMixing:
+     def __init__(self, args, object_of_interest_index, avg_cross_attn=None):
+         self.object_of_interest_index = object_of_interest_index
+         self.objects_to_preserve = [args.prompt.split().index(o) + 1 for o in args.objects_to_preserve]
+         self.obj_pixels_injection_threshold = args.obj_pixels_injection_threshold
+
+         self.start_other_prompt_range = args.start_prompt_range
+         self.end_other_prompt_range = args.end_prompt_range
+
+         self.start_cross_attn_replace_range = args.num_diffusion_steps
+         self.end_cross_attn_replace_range = args.num_diffusion_steps
+
+         self.start_self_attn_replace_range = 0
+         self.end_self_attn_replace_range = args.end_preserved_obj_self_attn_masking
+         self.remove_obj_from_self_mask = args.remove_obj_from_self_mask
+         self.avg_cross_attn = avg_cross_attn
+
+         self.low_resource = args.low_resource
+
+     def get_context_for_v(self, t, context, other_context):
+         if other_context is not None and \
+                 self.start_other_prompt_range <= t < self.end_other_prompt_range:
+             if self.low_resource:
+                 return other_context
+             else:
+                 v_context = context.clone()
+                 # first half of context is for the unconditioned image
+                 v_context[v_context.shape[0] // 2:] = other_context
+                 return v_context
+         else:
+             return context
+
+     def get_cross_attn(self, diffusion_model_wrapper, t, attn, place_in_unet, batch_size):
+         if self.start_cross_attn_replace_range <= t < self.end_cross_attn_replace_range:
+             if self.low_resource:
+                 attn[:, :, self.object_of_interest_index] = 0.2 * torch.from_numpy(medfilt2d(attn[:, :, self.object_of_interest_index].cpu().numpy(), kernel_size=3)).to(attn.device) + \
+                                                             0.8 * attn[:, :, self.object_of_interest_index]
+             else:
+                 # first half of attn maps is for the unconditioned image
+                 min_h = attn.shape[0] // 2
+                 attn[min_h:, :, self.object_of_interest_index] = 0.2 * torch.from_numpy(medfilt2d(attn[min_h:, :, self.object_of_interest_index].cpu().numpy(), kernel_size=3)).to(attn.device) + \
+                                                                  0.8 * attn[min_h:, :, self.object_of_interest_index]
+         return attn
+
+     def get_self_attn(self, diffusion_model_wrapper, t, attn, place_in_unet, batch_size):
+         if attn.shape[1] <= 32 ** 2 and \
+                 self.avg_cross_attn is not None and \
+                 self.start_self_attn_replace_range <= t < self.end_self_attn_replace_range:
+
+             key = f"{place_in_unet}_cross"
+             attn_index = getattr(diffusion_model_wrapper, f'{key}_index')
+             cr = self.avg_cross_attn[key][attn_index]
+             setattr(diffusion_model_wrapper, f'{key}_index', attn_index + 1)
+
+             if self.low_resource:
+                 attn = self.mask_self_attn_patches(attn, cr, batch_size)
+             else:
+                 # first half of attn maps is for the unconditioned image
+                 attn[attn.shape[0] // 2:] = self.mask_self_attn_patches(attn[attn.shape[0] // 2:], cr, batch_size // 2)
+
+         return attn
+
+     def mask_self_attn_patches(self, self_attn, cross_attn, batch_size):
+         h = self_attn.shape[0] // batch_size
+         tokens = self.objects_to_preserve
+         obj_token = self.object_of_interest_index
+
+         normalized_cross_attn = cross_attn - cross_attn.min()
+         normalized_cross_attn /= normalized_cross_attn.max()
+
+         mask = torch.zeros_like(self_attn[0])
+         for tk in tokens:
+             mask_tk_in = torch.unique((normalized_cross_attn[:, :, tk] > self.obj_pixels_injection_threshold).nonzero(as_tuple=True)[1])
+             mask[mask_tk_in, :] = 1
+             mask[:, mask_tk_in] = 1
+
+         if self.remove_obj_from_self_mask:
+             obj_patches = torch.unique((normalized_cross_attn[:, :, obj_token] > self.obj_pixels_injection_threshold).nonzero(as_tuple=True)[1])
+             mask[obj_patches, :] = 0
+             mask[:, obj_patches] = 0
+
+         self_attn[h:] = self_attn[h:] * (1 - mask) + self_attn[:h].repeat(batch_size - 1, 1, 1) * mask
+         return self_attn
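Note (not part of the commit): the core of prompt-mixing is `get_context_for_v`, which feeds the proxy prompt's embeddings to the cross-attention value (V) projections only inside the step window `[start_prompt_range, end_prompt_range)`, while queries and keys keep following the original prompt. The interval logic in isolation (mirroring the low_resource branch):

# Sketch of the step-window logic behind PromptMixing.get_context_for_v.
def context_for_v(step, context, other_context, start=7, end=17):
    if other_context is not None and start <= step < end:
        return other_context   # proxy prompt drives the V projections
    return context             # original prompt everywhere else

assert context_for_v(10, "orig", "proxy") == "proxy"
assert context_for_v(30, "orig", "proxy") == "orig"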
src/prompt_to_prompt_controllers.py ADDED
@@ -0,0 +1,205 @@
+ import torch
+ import numpy as np
+ import abc
+ from typing import Optional, Union, Tuple, Dict
+ import src.seq_aligner as seq_aligner
+
+
+ class AttentionControl(abc.ABC):
+
+     def step_callback(self, x_t):
+         return x_t
+
+     def between_steps(self):
+         return
+
+     @property
+     def num_uncond_att_layers(self):
+         return self.num_att_layers if self.low_resource else 0
+
+     @abc.abstractmethod
+     def forward(self, attn, is_cross: bool, place_in_unet: str):
+         raise NotImplementedError
+
+     def __call__(self, attn, is_cross: bool, place_in_unet: str):
+         if self.cur_att_layer >= self.num_uncond_att_layers:
+             if self.low_resource:
+                 attn = self.forward(attn, is_cross, place_in_unet)
+             else:
+                 h = attn.shape[0]
+                 attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
+         self.cur_att_layer += 1
+         if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
+             self.cur_att_layer = 0
+             self.cur_step += 1
+             self.between_steps()
+         return attn
+
+     def reset(self):
+         self.cur_step = 0
+         self.cur_att_layer = 0
+
+     def __init__(self, low_resource):
+         self.cur_step = 0
+         self.num_att_layers = -1
+         self.cur_att_layer = 0
+         self.low_resource = low_resource
+
+
+ class EmptyControl(AttentionControl):
+
+     def forward(self, attn, is_cross: bool, place_in_unet: str):
+         return attn
+
+
+ class DummyController:
+     def __call__(self, *args):
+         return args[0]
+
+     def __init__(self):
+         self.num_att_layers = 0
+
+
+ class AttentionStore(AttentionControl):
+
+     @staticmethod
+     def get_empty_store():
+         return {"down_cross": [], "mid_cross": [], "up_cross": [],
+                 "down_self": [], "mid_self": [], "up_self": []}
+
+     def forward(self, attn, is_cross: bool, place_in_unet: str):
+         key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
+         if attn.shape[1] <= 32 ** 2:  # avoid memory overhead
+             self.step_store[key].append(attn)
+         return attn
+
+     def between_steps(self):
+         if len(self.attention_store) == 0:
+             self.attention_store = self.step_store
+         else:
+             for key in self.attention_store:
+                 for i in range(len(self.attention_store[key])):
+                     self.attention_store[key][i] += self.step_store[key][i]
+         self.step_store = self.get_empty_store()
+
+     def get_average_attention(self):
+         average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in
+                              self.attention_store}
+         return average_attention
+
+     def reset(self):
+         super(AttentionStore, self).reset()
+         self.step_store = self.get_empty_store()
+         self.attention_store = {}
+
+     def __init__(self, low_resource):
+         super(AttentionStore, self).__init__(low_resource)
+         self.step_store = self.get_empty_store()
+         self.attention_store = {}
+
+
+ class AttentionControlEdit(AttentionStore, abc.ABC):
+
+     def step_callback(self, x_t):
+         return x_t
+
+     def replace_self_attention(self, attn_base, att_replace):
+         if att_replace.shape[2] <= 16 ** 2:
+             return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
+         else:
+             return att_replace
+
+     @abc.abstractmethod
+     def replace_cross_attention(self, attn_base, att_replace):
+         raise NotImplementedError
+
+     def forward(self, attn, is_cross: bool, place_in_unet: str):
+         super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
+         if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
+             h = attn.shape[0] // (self.batch_size)
+             attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
+             attn_base, attn_replace = attn[0], attn[1:]
+             if is_cross:
+                 alpha_words = self.cross_replace_alpha[self.cur_step]
+                 attn_replace_new = self.replace_cross_attention(attn_base, attn_replace) * alpha_words + (
+                         1 - alpha_words) * attn_replace
+                 attn[1:] = attn_replace_new
+             else:
+                 attn[1:] = self.replace_self_attention(attn_base, attn_replace)
+             attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
+         return attn
+
+     def __init__(self, prompts, tokenizer, device, low_resource, num_steps: int,
+                  cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
+                  self_replace_steps: Union[float, Tuple[float, float]]):
+         super(AttentionControlEdit, self).__init__(low_resource)
+         self.batch_size = len(prompts)
+         self.tokenizer = tokenizer
+         self.cross_replace_alpha = get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps,
+                                                                   self.tokenizer).to(device)
+         if type(self_replace_steps) is float:
+             self_replace_steps = 0, self_replace_steps
+         self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
+
+
+ class AttentionReplace(AttentionControlEdit):
+
+     def replace_cross_attention(self, attn_base, att_replace):
+         return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper.to(attn_base.dtype))
+
+     def __init__(self, prompts, tokenizer, device, low_resource, num_steps: int, cross_replace_steps: float, self_replace_steps: float):
+         super(AttentionReplace, self).__init__(prompts, tokenizer, device, low_resource, num_steps, cross_replace_steps, self_replace_steps)
+         self.mapper = seq_aligner.get_replacement_mapper(prompts, self.tokenizer).to(device)
+
+
+ def get_word_inds(text: str, word_place: int, tokenizer):
+     split_text = text.split(" ")
+     if type(word_place) is str:
+         word_place = [i for i, word in enumerate(split_text) if word_place == word]
+     elif type(word_place) is int:
+         word_place = [word_place]
+     out = []
+     if len(word_place) > 0:
+         words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
+         cur_len, ptr = 0, 0
+
+         for i in range(len(words_encode)):
+             cur_len += len(words_encode[i])
+             if ptr in word_place:
+                 out.append(i + 1)
+             if cur_len >= len(split_text[ptr]):
+                 ptr += 1
+                 cur_len = 0
+     return np.array(out)
+
+
+ def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, word_inds: Optional[torch.Tensor] = None):
+     if type(bounds) is float:
+         bounds = 0, bounds
+     start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
+     if word_inds is None:
+         word_inds = torch.arange(alpha.shape[2])
+     alpha[: start, prompt_ind, word_inds] = 0
+     alpha[start: end, prompt_ind, word_inds] = 1
+     alpha[end:, prompt_ind, word_inds] = 0
+     return alpha
+
+
+ def get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
+                                    tokenizer, max_num_words=77):
+     if type(cross_replace_steps) is not dict:
+         cross_replace_steps = {"default_": cross_replace_steps}
+     if "default_" not in cross_replace_steps:
+         cross_replace_steps["default_"] = (0., 1.)
+     alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
+     for i in range(len(prompts) - 1):
+         alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
+                                                   i)
+     for key, item in cross_replace_steps.items():
+         if key != "default_":
+             inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
+             for i, ind in enumerate(inds):
+                 if len(ind) > 0:
+                     alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
+     alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)  # time, batch, heads, pixels, words
+     return alpha_time_words
src/prompt_utils.py ADDED
@@ -0,0 +1,64 @@
1
+ import json
2
+ import torch
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+
6
+
7
+ def get_topk_similar_words(model, prompt, base_word, vocab, k=30):
8
+ text_input = model.tokenizer(
9
+ [prompt.format(word=base_word)],
10
+ padding="max_length",
11
+ max_length=model.tokenizer.model_max_length,
12
+ truncation=True,
13
+ return_tensors="pt",
14
+ )
15
+ with torch.no_grad():
16
+ encoder_output = model.text_encoder(text_input.input_ids.to(model.device))
17
+ full_prompt_embedding = encoder_output.pooler_output
18
+ full_prompt_embedding = full_prompt_embedding / full_prompt_embedding.norm(p=2, dim=-1, keepdim=True)
19
+
20
+ prompts = [prompt.format(word=word) for word in vocab]
21
+ batch_size = 1000
22
+ all_prompts_embeddings = []
23
+ for i in tqdm(range(0, len(prompts), batch_size)):
24
+ curr_prompts = prompts[i:i + batch_size]
25
+ with torch.no_grad():
26
+ text_input = model.tokenizer(
27
+ curr_prompts,
28
+ padding="max_length",
29
+ max_length=model.tokenizer.model_max_length,
30
+ truncation=True,
31
+ return_tensors="pt",
32
+ )
33
+ curr_embeddings = model.text_encoder(text_input.input_ids.to(model.device)).pooler_output
34
+ all_prompts_embeddings.append(curr_embeddings)
35
+
36
+ all_prompts_embeddings = torch.cat(all_prompts_embeddings)
37
+ all_prompts_embeddings = all_prompts_embeddings / all_prompts_embeddings.norm(p=2, dim=-1, keepdim=True)
38
+ prompts_similarities = all_prompts_embeddings.matmul(full_prompt_embedding.view(-1, 1))
39
+ sorted_prompts_similarities = np.flip(prompts_similarities.cpu().numpy().reshape(-1).argsort())
40
+
41
+ print(f"prompt: {prompt}")
42
+ print(f"initial word: {base_word}")
43
+ print(f"TOP {k} SIMILAR WORDS:")
44
+ similar_words = [vocab[index] for index in sorted_prompts_similarities[:k]]
45
+ print(similar_words)
46
+ return similar_words
47
+
48
+ def get_proxy_words(args, ldm_stable):
49
+ if len(args.proxy_words) > 0:
50
+ return [args.object_of_interest] + args.proxy_words
51
+ vocab = list(json.load(open("vocab.json")).keys())
52
+ vocab = [word for word in vocab if word.isalpha() and len(word) > 1]
53
+ filtered_vocab = get_topk_similar_words(ldm_stable, "a photo of a {word}", args.object_of_interest, vocab, k=50)
54
+ proxy_words = get_topk_similar_words(ldm_stable, args.prompt, args.object_of_interest, filtered_vocab, k=args.number_of_variations)
55
+ if proxy_words[0] != args.object_of_interest:
56
+ proxy_words = [args.object_of_interest] + proxy_words
57
+
58
+ return proxy_words
59
+
60
+ def get_proxy_prompts(args, ldm_stable):
61
+ proxy_words = get_proxy_words(args, ldm_stable)
62
+ prompts = [args.prompt.format(word=args.object_of_interest)]
63
+ proxy_prompts = [{"word": word, "prompt": args.prompt.format(word=word)} for word in proxy_words]
64
+ return proxy_words, prompts, proxy_prompts
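Editor's note: a hedged usage sketch, not part of the commit, of the proxy-word mining above. It assumes a diffusers StableDiffusionPipeline (its .tokenizer, .text_encoder and .device attributes match what get_topk_similar_words uses), that the repository root is on PYTHONPATH, that vocab.json sits in the working directory, and that the remaining LPMConfig fields have defaults; the model id and values are illustrative.

from diffusers import StableDiffusionPipeline
from main import LPMConfig
from src.prompt_utils import get_proxy_prompts

ldm_stable = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to("cuda")

args = LPMConfig(
    prompt="A table below a {word}",
    object_of_interest="lamp",
    proxy_words=[],            # empty list -> proxy words are mined from vocab.json
    number_of_variations=7,
)

proxy_words, prompts, proxy_prompts = get_proxy_prompts(args, ldm_stable)
print(proxy_words)    # e.g. ["lamp", "lantern", ...] (mined words will vary)
print(prompts)        # ["A table below a lamp"]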
src/seq_aligner.py ADDED
@@ -0,0 +1,195 @@
1
+ # Copyright 2022 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import torch
15
+ import numpy as np
16
+
17
+
18
+ class ScoreParams:
19
+
20
+ def __init__(self, gap, match, mismatch):
21
+ self.gap = gap
22
+ self.match = match
23
+ self.mismatch = mismatch
24
+
25
+ def mis_match_char(self, x, y):
26
+ if x != y:
27
+ return self.mismatch
28
+ else:
29
+ return self.match
30
+
31
+
32
+ def get_matrix(size_x, size_y, gap):  # list-based version; superseded by the NumPy implementation below
33
+ matrix = []
34
+ for i in range(size_x + 1):
35
+ sub_matrix = []
36
+ for j in range(size_y + 1):
37
+ sub_matrix.append(0)
38
+ matrix.append(sub_matrix)
39
+ for j in range(1, size_y + 1):
40
+ matrix[0][j] = j*gap
41
+ for i in range(1, size_x + 1):
42
+ matrix[i][0] = i*gap
43
+ return matrix
44
+
45
+
46
+ def get_matrix(size_x, size_y, gap):
47
+ matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
48
+ matrix[0, 1:] = (np.arange(size_y) + 1) * gap
49
+ matrix[1:, 0] = (np.arange(size_x) + 1) * gap
50
+ return matrix
51
+
52
+
53
+ def get_traceback_matrix(size_x, size_y):
54
+ matrix = np.zeros((size_x + 1, size_y +1), dtype=np.int32)
55
+ matrix[0, 1:] = 1
56
+ matrix[1:, 0] = 2
57
+ matrix[0, 0] = 4
58
+ return matrix
59
+
60
+
61
+ def global_align(x, y, score):
62
+ matrix = get_matrix(len(x), len(y), score.gap)
63
+ trace_back = get_traceback_matrix(len(x), len(y))
64
+ for i in range(1, len(x) + 1):
65
+ for j in range(1, len(y) + 1):
66
+ left = matrix[i, j - 1] + score.gap
67
+ up = matrix[i - 1, j] + score.gap
68
+ diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])
69
+ matrix[i, j] = max(left, up, diag)
70
+ if matrix[i, j] == left:
71
+ trace_back[i, j] = 1
72
+ elif matrix[i, j] == up:
73
+ trace_back[i, j] = 2
74
+ else:
75
+ trace_back[i, j] = 3
76
+ return matrix, trace_back
77
+
78
+
79
+ def get_aligned_sequences(x, y, trace_back):
80
+ x_seq = []
81
+ y_seq = []
82
+ i = len(x)
83
+ j = len(y)
84
+ mapper_y_to_x = []
85
+ while i > 0 or j > 0:
86
+ if trace_back[i, j] == 3:
87
+ x_seq.append(x[i-1])
88
+ y_seq.append(y[j-1])
89
+ i = i-1
90
+ j = j-1
91
+ mapper_y_to_x.append((j, i))
92
+ elif trace_back[i][j] == 1:
93
+ x_seq.append('-')
94
+ y_seq.append(y[j-1])
95
+ j = j-1
96
+ mapper_y_to_x.append((j, -1))
97
+ elif trace_back[i][j] == 2:
98
+ x_seq.append(x[i-1])
99
+ y_seq.append('-')
100
+ i = i-1
101
+ elif trace_back[i][j] == 4:
102
+ break
103
+ mapper_y_to_x.reverse()
104
+ return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64)
105
+
106
+
107
+ def get_mapper(x: str, y: str, tokenizer, max_len=77):
108
+ x_seq = tokenizer.encode(x)
109
+ y_seq = tokenizer.encode(y)
110
+ score = ScoreParams(0, 1, -1)
111
+ matrix, trace_back = global_align(x_seq, y_seq, score)
112
+ mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]
113
+ alphas = torch.ones(max_len)
114
+ alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float()
115
+ mapper = torch.zeros(max_len, dtype=torch.int64)
116
+ mapper[:mapper_base.shape[0]] = mapper_base[:, 1]
117
+ mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq))
118
+ return mapper, alphas
119
+
120
+
121
+ def get_refinement_mapper(prompts, tokenizer, max_len=77):
122
+ x_seq = prompts[0]
123
+ mappers, alphas = [], []
124
+ for i in range(1, len(prompts)):
125
+ mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len)
126
+ mappers.append(mapper)
127
+ alphas.append(alpha)
128
+ return torch.stack(mappers), torch.stack(alphas)
129
+
130
+
131
+ def get_word_inds(text: str, word_place: int, tokenizer):
132
+ split_text = text.split(" ")
133
+ if type(word_place) is str:
134
+ word_place = [i for i, word in enumerate(split_text) if word_place == word]
135
+ elif type(word_place) is int:
136
+ word_place = [word_place]
137
+ out = []
138
+ if len(word_place) > 0:
139
+ words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
140
+ cur_len, ptr = 0, 0
141
+
142
+ for i in range(len(words_encode)):
143
+ cur_len += len(words_encode[i])
144
+ if ptr in word_place:
145
+ out.append(i + 1)
146
+ if cur_len >= len(split_text[ptr]):
147
+ ptr += 1
148
+ cur_len = 0
149
+ return np.array(out)
150
+
151
+
152
+ def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
153
+ words_x = x.split(' ')
154
+ words_y = y.split(' ')
155
+ if len(words_x) != len(words_y):
156
+ raise ValueError(f"attention replacement edit can only be applied on prompts with the same length"
157
+ f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.")
158
+ inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
159
+ inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
160
+ inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
161
+ mapper = np.zeros((max_len, max_len))
162
+ i = j = 0
163
+ cur_inds = 0
164
+ while i < max_len and j < max_len:
165
+ if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
166
+ inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
167
+ if len(inds_source_) == len(inds_target_):
168
+ mapper[inds_source_, inds_target_] = 1
169
+ else:
170
+ ratio = 1 / len(inds_target_)
171
+ for i_t in inds_target_:
172
+ mapper[inds_source_, i_t] = ratio
173
+ cur_inds += 1
174
+ i += len(inds_source_)
175
+ j += len(inds_target_)
176
+ elif cur_inds < len(inds_source):
177
+ mapper[i, j] = 1
178
+ i += 1
179
+ j += 1
180
+ else:
181
+ mapper[j, j] = 1
182
+ i += 1
183
+ j += 1
184
+
185
+ return torch.from_numpy(mapper).float()
186
+
187
+
188
+ def get_replacement_mapper(prompts, tokenizer, max_len=77):
189
+ x_seq = prompts[0]
190
+ mappers = []
191
+ for i in range(1, len(prompts)):
192
+ mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len)
193
+ mappers.append(mapper)
194
+ return torch.stack(mappers)
195
+
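Editor's note: a minimal sketch, not part of the commit, of the two mapper entry points above, again assuming the Stable Diffusion CLIP tokenizer and that the repository root is importable; the example prompts are illustrative.

from transformers import CLIPTokenizer
from src.seq_aligner import get_replacement_mapper, get_refinement_mapper

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")  # assumed tokenizer

# Replacement: prompts must have the same number of words; each mapper row tells how a
# source token's attention is redistributed over the target tokens.
mapper = get_replacement_mapper(["a table below a lamp", "a table below a vase"], tokenizer)
print(mapper.shape)   # torch.Size([1, 77, 77])

# Refinement: prompts may differ in length; tokens are aligned with global_align.
mappers, alphas = get_refinement_mapper(["a table below a lamp", "a wooden table below a lamp"], tokenizer)
print(mappers.shape, alphas.shape)   # torch.Size([1, 77]) torch.Size([1, 77])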
style.css ADDED
@@ -0,0 +1,3 @@
1
+ h1 {
2
+ text-align: center;
3
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff