ljzycmd commited on
Commit
5fc5efa
·
1 Parent(s): bfee1f8

Add hugging face space demo.

Browse files
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ from diffusers import DDIMScheduler
5
+ from pytorch_lightning import seed_everything
6
+
7
+ from masactrl.diffuser_utils import MasaCtrlPipeline
8
+ from masactrl.masactrl_utils import (AttentionBase,
9
+ regiter_attention_editor_diffusers)
10
+
11
+ torch.set_grad_enabled(False)
12
+
13
+ from gradio_app.image_synthesis_app import create_demo_synthesis
14
+ from gradio_app.real_image_editing_app import create_demo_editing
15
+
16
+ from gradio_app.app_utils import global_context
17
+
18
+
19
+ TITLE = "# [MasaCtrl](https://ljzycmd.github.io/projects/MasaCtrl/)"
20
+ DESCRIPTION = "<b>Gradio demo for MasaCtrl</b>: [[GitHub]](https://github.com/TencentARC/MasaCtrl), \
21
+ [[Paper]](https://arxiv.org/abs/2304.08465). \
22
+ If MasaCtrl is helpful, please help to ⭐ the [Github Repo](https://github.com/TencentARC/MasaCtrl) 😊 </p>"
23
+
24
+ DESCRIPTION += '<p>For faster inference without waiting in queue, \
25
+ you may duplicate the space and upgrade to GPU in settings. </p>'
26
+
27
+
28
+ with gr.Blocks(css="style.css") as demo:
29
+ gr.Markdown(TITLE)
30
+ gr.Markdown(DESCRIPTION)
31
+ model_path_gr = gr.Dropdown(
32
+ ["andite/anything-v4.0",
33
+ "CompVis/stable-diffusion-v1-4",
34
+ "runwayml/stable-diffusion-v1-5"],
35
+ value="andite/anything-v4.0",
36
+ label="Model", info="Select the model to use!"
37
+ )
38
+ with gr.Tab("Consistent Synthesis"):
39
+ create_demo_synthesis()
40
+ with gr.Tab("Real Editing"):
41
+ create_demo_editing()
42
+
43
+ def reload_ckpt(model_path):
44
+ print("Reloading model from", model_path)
45
+ global_context["model"] = MasaCtrlPipeline.from_pretrained(
46
+ model_path, scheduler=global_context["scheduler"]).to(global_context["device"])
47
+
48
+ model_path_gr.select(
49
+ reload_ckpt,
50
+ [model_path_gr]
51
+ )
52
+
53
+
54
+ if __name__ == "__main__":
55
+ demo.launch()
gradio_app/app_utils.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ from diffusers import DDIMScheduler
5
+ from pytorch_lightning import seed_everything
6
+
7
+ from masactrl.diffuser_utils import MasaCtrlPipeline
8
+ from masactrl.masactrl_utils import (AttentionBase,
9
+ regiter_attention_editor_diffusers)
10
+
11
+
12
+ torch.set_grad_enabled(False)
13
+
14
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
15
+ "cpu")
16
+ model_path = "andite/anything-v4.0"
17
+ scheduler = DDIMScheduler(beta_start=0.00085,
18
+ beta_end=0.012,
19
+ beta_schedule="scaled_linear",
20
+ clip_sample=False,
21
+ set_alpha_to_one=False)
22
+ model = MasaCtrlPipeline.from_pretrained(model_path,
23
+ scheduler=scheduler).to(device)
24
+
25
+ global_context = {
26
+ "model_path": model_path,
27
+ "scheduler": scheduler,
28
+ "model": model,
29
+ "device": device
30
+ }
gradio_app/image_synthesis_app.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ from diffusers import DDIMScheduler
5
+ from pytorch_lightning import seed_everything
6
+
7
+ from masactrl.diffuser_utils import MasaCtrlPipeline
8
+ from masactrl.masactrl_utils import (AttentionBase,
9
+ regiter_attention_editor_diffusers)
10
+
11
+ from .app_utils import global_context
12
+
13
+ torch.set_grad_enabled(False)
14
+
15
+ # device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
16
+ # "cpu")
17
+ # model_path = "andite/anything-v4.0"
18
+ # scheduler = DDIMScheduler(beta_start=0.00085,
19
+ # beta_end=0.012,
20
+ # beta_schedule="scaled_linear",
21
+ # clip_sample=False,
22
+ # set_alpha_to_one=False)
23
+ # model = MasaCtrlPipeline.from_pretrained(model_path,
24
+ # scheduler=scheduler).to(device)
25
+
26
+
27
+ def consistent_synthesis(source_prompt, target_prompt, starting_step,
28
+ starting_layer, image_resolution, ddim_steps, scale,
29
+ seed, appended_prompt, negative_prompt):
30
+ from masactrl.masactrl import MutualSelfAttentionControl
31
+
32
+ model = global_context["model"]
33
+ device = global_context["device"]
34
+
35
+ seed_everything(seed)
36
+
37
+ with torch.no_grad():
38
+ if appended_prompt is not None:
39
+ source_prompt += appended_prompt
40
+ target_prompt += appended_prompt
41
+ prompts = [source_prompt, target_prompt]
42
+
43
+ # initialize the noise map
44
+ start_code = torch.randn([1, 4, 64, 64], device=device)
45
+ start_code = start_code.expand(len(prompts), -1, -1, -1)
46
+
47
+ # inference the synthesized image without MasaCtrl
48
+ editor = AttentionBase()
49
+ regiter_attention_editor_diffusers(model, editor)
50
+ target_image_ori = model([target_prompt],
51
+ latents=start_code[-1:],
52
+ guidance_scale=7.5)
53
+ target_image_ori = target_image_ori.cpu().permute(0, 2, 3, 1).numpy()
54
+
55
+ # inference the synthesized image with MasaCtrl
56
+ # hijack the attention module
57
+ controller = MutualSelfAttentionControl(starting_step, starting_layer)
58
+ regiter_attention_editor_diffusers(model, controller)
59
+
60
+ # inference the synthesized image
61
+ image_masactrl = model(prompts, latents=start_code, guidance_scale=7.5)
62
+ image_masactrl = image_masactrl.cpu().permute(0, 2, 3, 1).numpy()
63
+
64
+ return [image_masactrl[0], target_image_ori[0],
65
+ image_masactrl[1]] # source, fixed seed, masactrl
66
+
67
+
68
+ def create_demo_synthesis():
69
+ with gr.Blocks() as demo:
70
+ gr.Markdown("## **Input Settings**")
71
+ with gr.Row():
72
+ with gr.Column():
73
+ source_prompt = gr.Textbox(
74
+ label="Source Prompt",
75
+ value='1boy, casual, outdoors, sitting',
76
+ interactive=True)
77
+ target_prompt = gr.Textbox(
78
+ label="Target Prompt",
79
+ value='1boy, casual, outdoors, standing',
80
+ interactive=True)
81
+ with gr.Row():
82
+ ddim_steps = gr.Slider(label="DDIM Steps",
83
+ minimum=1,
84
+ maximum=999,
85
+ value=50,
86
+ step=1)
87
+ starting_step = gr.Slider(
88
+ label="Step of MasaCtrl",
89
+ minimum=0,
90
+ maximum=999,
91
+ value=4,
92
+ step=1)
93
+ starting_layer = gr.Slider(label="Layer of MasaCtrl",
94
+ minimum=0,
95
+ maximum=16,
96
+ value=10,
97
+ step=1)
98
+ run_btn = gr.Button(label="Run")
99
+ with gr.Column():
100
+ appended_prompt = gr.Textbox(label="Appended Prompt", value='')
101
+ negative_prompt = gr.Textbox(label="Negative Prompt", value='')
102
+ with gr.Row():
103
+ image_resolution = gr.Slider(label="Image Resolution",
104
+ minimum=256,
105
+ maximum=768,
106
+ value=512,
107
+ step=64)
108
+ scale = gr.Slider(label="CFG Scale",
109
+ minimum=0.1,
110
+ maximum=30.0,
111
+ value=7.5,
112
+ step=0.1)
113
+ seed = gr.Slider(label="Seed",
114
+ minimum=-1,
115
+ maximum=2147483647,
116
+ value=42,
117
+ step=1)
118
+
119
+ gr.Markdown("## **Output**")
120
+ with gr.Row():
121
+ image_source = gr.Image(label="Source Image")
122
+ image_fixed = gr.Image(label="Image with Fixed Seed")
123
+ image_masactrl = gr.Image(label="Image with MasaCtrl")
124
+
125
+ inputs = [
126
+ source_prompt, target_prompt, starting_step, starting_layer,
127
+ image_resolution, ddim_steps, scale, seed, appended_prompt,
128
+ negative_prompt
129
+ ]
130
+ run_btn.click(consistent_synthesis, inputs,
131
+ [image_source, image_fixed, image_masactrl])
132
+
133
+ gr.Examples(
134
+ [[
135
+ "1boy, bishounen, casual, indoors, sitting, coffee shop, bokeh",
136
+ "1boy, bishounen, casual, indoors, standing, coffee shop, bokeh",
137
+ 42
138
+ ],
139
+ [
140
+ "1boy, casual, outdoors, sitting",
141
+ "1boy, casual, outdoors, sitting, side view", 42
142
+ ],
143
+ [
144
+ "1boy, casual, outdoors, sitting",
145
+ "1boy, casual, outdoors, standing, clapping hands", 42
146
+ ],
147
+ [
148
+ "1boy, casual, outdoors, sitting",
149
+ "1boy, casual, outdoors, sitting, shows thumbs up", 42
150
+ ],
151
+ [
152
+ "1boy, casual, outdoors, sitting",
153
+ "1boy, casual, outdoors, sitting, with crossed arms", 42
154
+ ],
155
+ [
156
+ "1boy, casual, outdoors, sitting",
157
+ "1boy, casual, outdoors, sitting, rasing hands", 42
158
+ ]],
159
+ [source_prompt, target_prompt, seed],
160
+ )
161
+ return demo
162
+
163
+
164
+ if __name__ == "__main__":
165
+ demo_syntehsis = create_demo_synthesis()
166
+ demo_synthesis.launch()
gradio_app/images/corgi.jpg ADDED
gradio_app/images/person.png ADDED
gradio_app/real_image_editing_app.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import gradio as gr
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from diffusers import DDIMScheduler
7
+ from torchvision.io import read_image
8
+ from pytorch_lightning import seed_everything
9
+
10
+ from masactrl.diffuser_utils import MasaCtrlPipeline
11
+ from masactrl.masactrl_utils import (AttentionBase,
12
+ regiter_attention_editor_diffusers)
13
+
14
+ from .app_utils import global_context
15
+
16
+ torch.set_grad_enabled(False)
17
+
18
+ # device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
19
+ # "cpu")
20
+
21
+ # model_path = "CompVis/stable-diffusion-v1-4"
22
+ # scheduler = DDIMScheduler(beta_start=0.00085,
23
+ # beta_end=0.012,
24
+ # beta_schedule="scaled_linear",
25
+ # clip_sample=False,
26
+ # set_alpha_to_one=False)
27
+ # model = MasaCtrlPipeline.from_pretrained(model_path,
28
+ # scheduler=scheduler).to(device)
29
+
30
+
31
+ def load_image(image_path):
32
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
33
+ image = read_image(image_path)
34
+ image = image[:3].unsqueeze_(0).float() / 127.5 - 1. # [-1, 1]
35
+ image = F.interpolate(image, (512, 512))
36
+ image = image.to(device)
37
+
38
+
39
+ def real_image_editing(source_image, target_prompt,
40
+ starting_step, starting_layer, ddim_steps, scale, seed,
41
+ appended_prompt, negative_prompt):
42
+ from masactrl.masactrl import MutualSelfAttentionControl
43
+
44
+ model = global_context["model"]
45
+ device = global_context["device"]
46
+
47
+ seed_everything(seed)
48
+
49
+ with torch.no_grad():
50
+ if appended_prompt is not None:
51
+ target_prompt += appended_prompt
52
+ ref_prompt = ""
53
+ prompts = [ref_prompt, target_prompt]
54
+
55
+ # invert the image into noise map
56
+ if isinstance(source_image, np.ndarray):
57
+ source_image = torch.from_numpy(source_image).to(device) / 127.5 - 1.
58
+ source_image = source_image.unsqueeze(0).permute(0, 3, 1, 2)
59
+ source_image = F.interpolate(source_image, (512, 512))
60
+
61
+ start_code, latents_list = model.invert(source_image,
62
+ ref_prompt,
63
+ guidance_scale=scale,
64
+ num_inference_steps=ddim_steps,
65
+ return_intermediates=True)
66
+ start_code = start_code.expand(len(prompts), -1, -1, -1)
67
+
68
+ # recontruct the image with inverted DDIM noise map
69
+ editor = AttentionBase()
70
+ regiter_attention_editor_diffusers(model, editor)
71
+ image_fixed = model([target_prompt],
72
+ latents=start_code[-1:],
73
+ num_inference_steps=ddim_steps,
74
+ guidance_scale=scale)
75
+ image_fixed = image_fixed.cpu().permute(0, 2, 3, 1).numpy()
76
+
77
+ # inference the synthesized image with MasaCtrl
78
+ # hijack the attention module
79
+ controller = MutualSelfAttentionControl(starting_step, starting_layer)
80
+ regiter_attention_editor_diffusers(model, controller)
81
+
82
+ # inference the synthesized image
83
+ image_masactrl = model(prompts,
84
+ latents=start_code,
85
+ guidance_scale=scale)
86
+ image_masactrl = image_masactrl.cpu().permute(0, 2, 3, 1).numpy()
87
+
88
+ return [
89
+ image_masactrl[0],
90
+ image_fixed[0],
91
+ image_masactrl[1]
92
+ ] # source, fixed seed, masactrl
93
+
94
+
95
+ def create_demo_editing():
96
+ with gr.Blocks() as demo:
97
+ gr.Markdown("## **Input Settings**")
98
+ with gr.Row():
99
+ with gr.Column():
100
+ source_image = gr.Image(label="Source Image", value=os.path.join(os.path.dirname(__file__), "images/corgi.jpg"), interactive=True)
101
+ target_prompt = gr.Textbox(label="Target Prompt",
102
+ value='A photo of a running corgi',
103
+ interactive=True)
104
+ with gr.Row():
105
+ ddim_steps = gr.Slider(label="DDIM Steps",
106
+ minimum=1,
107
+ maximum=999,
108
+ value=50,
109
+ step=1)
110
+ starting_step = gr.Slider(label="Step of MasaCtrl",
111
+ minimum=0,
112
+ maximum=999,
113
+ value=4,
114
+ step=1)
115
+ starting_layer = gr.Slider(label="Layer of MasaCtrl",
116
+ minimum=0,
117
+ maximum=16,
118
+ value=10,
119
+ step=1)
120
+ run_btn = gr.Button(label="Run")
121
+ with gr.Column():
122
+ appended_prompt = gr.Textbox(label="Appended Prompt", value='')
123
+ negative_prompt = gr.Textbox(label="Negative Prompt", value='')
124
+ with gr.Row():
125
+ scale = gr.Slider(label="CFG Scale",
126
+ minimum=0.1,
127
+ maximum=30.0,
128
+ value=7.5,
129
+ step=0.1)
130
+ seed = gr.Slider(label="Seed",
131
+ minimum=-1,
132
+ maximum=2147483647,
133
+ value=42,
134
+ step=1)
135
+
136
+ gr.Markdown("## **Output**")
137
+ with gr.Row():
138
+ image_recons = gr.Image(label="Source Image")
139
+ image_fixed = gr.Image(label="Image with Fixed Seed")
140
+ image_masactrl = gr.Image(label="Image with MasaCtrl")
141
+
142
+ inputs = [
143
+ source_image, target_prompt, starting_step, starting_layer, ddim_steps,
144
+ scale, seed, appended_prompt, negative_prompt
145
+ ]
146
+ run_btn.click(real_image_editing, inputs,
147
+ [image_recons, image_fixed, image_masactrl])
148
+
149
+ gr.Examples(
150
+ [[os.path.join(os.path.dirname(__file__), "images/corgi.jpg"),
151
+ "A photo of a running corgi"],
152
+ [os.path.join(os.path.dirname(__file__), "images/person.png"),
153
+ "A photo of a person, black t-shirt, raising hand"],
154
+ ],
155
+ [source_image, target_prompt]
156
+ )
157
+ return demo
158
+
159
+
160
+ if __name__ == "__main__":
161
+ demo_editing = create_demo_editing()
162
+ demo_editing.launch()
masactrl/__init__.py ADDED
File without changes
masactrl/diffuser_utils.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Util functions based on Diffuser framework.
3
+ """
4
+
5
+
6
+ import os
7
+ import torch
8
+ import cv2
9
+ import numpy as np
10
+
11
+ import torch.nn.functional as F
12
+ from tqdm import tqdm
13
+ from PIL import Image
14
+ from torchvision.utils import save_image
15
+ from torchvision.io import read_image
16
+
17
+ from diffusers import StableDiffusionPipeline
18
+
19
+ from pytorch_lightning import seed_everything
20
+
21
+
22
+ class MasaCtrlPipeline(StableDiffusionPipeline):
23
+
24
+ def next_step(
25
+ self,
26
+ model_output: torch.FloatTensor,
27
+ timestep: int,
28
+ x: torch.FloatTensor,
29
+ eta=0.,
30
+ verbose=False
31
+ ):
32
+ """
33
+ Inverse sampling for DDIM Inversion
34
+ """
35
+ if verbose:
36
+ print("timestep: ", timestep)
37
+ next_step = timestep
38
+ timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
39
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
40
+ alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
41
+ beta_prod_t = 1 - alpha_prod_t
42
+ pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
43
+ pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
44
+ x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
45
+ return x_next, pred_x0
46
+
47
+ def step(
48
+ self,
49
+ model_output: torch.FloatTensor,
50
+ timestep: int,
51
+ x: torch.FloatTensor,
52
+ eta: float=0.0,
53
+ verbose=False,
54
+ ):
55
+ """
56
+ predict the sampe the next step in the denoise process.
57
+ """
58
+ prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
59
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
60
+ alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod
61
+ beta_prod_t = 1 - alpha_prod_t
62
+ pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
63
+ pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output
64
+ x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir
65
+ return x_prev, pred_x0
66
+
67
+ @torch.no_grad()
68
+ def image2latent(self, image):
69
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
70
+ if type(image) is Image:
71
+ image = np.array(image)
72
+ image = torch.from_numpy(image).float() / 127.5 - 1
73
+ image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
74
+ # input image density range [-1, 1]
75
+ latents = self.vae.encode(image)['latent_dist'].mean
76
+ latents = latents * 0.18215
77
+ return latents
78
+
79
+ @torch.no_grad()
80
+ def latent2image(self, latents, return_type='np'):
81
+ latents = 1 / 0.18215 * latents.detach()
82
+ image = self.vae.decode(latents)['sample']
83
+ if return_type == 'np':
84
+ image = (image / 2 + 0.5).clamp(0, 1)
85
+ image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
86
+ image = (image * 255).astype(np.uint8)
87
+ elif return_type == "pt":
88
+ image = (image / 2 + 0.5).clamp(0, 1)
89
+
90
+ return image
91
+
92
+ def latent2image_grad(self, latents):
93
+ latents = 1 / 0.18215 * latents
94
+ image = self.vae.decode(latents)['sample']
95
+
96
+ return image # range [-1, 1]
97
+
98
+ @torch.no_grad()
99
+ def __call__(
100
+ self,
101
+ prompt,
102
+ batch_size=1,
103
+ height=512,
104
+ width=512,
105
+ num_inference_steps=50,
106
+ guidance_scale=7.5,
107
+ eta=0.0,
108
+ latents=None,
109
+ unconditioning=None,
110
+ neg_prompt=None,
111
+ ref_intermediate_latents=None,
112
+ return_intermediates=False,
113
+ **kwds):
114
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
115
+ if isinstance(prompt, list):
116
+ batch_size = len(prompt)
117
+ elif isinstance(prompt, str):
118
+ if batch_size > 1:
119
+ prompt = [prompt] * batch_size
120
+
121
+ # text embeddings
122
+ text_input = self.tokenizer(
123
+ prompt,
124
+ padding="max_length",
125
+ max_length=77,
126
+ return_tensors="pt"
127
+ )
128
+
129
+ text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
130
+ print("input text embeddings :", text_embeddings.shape)
131
+ if kwds.get("dir"):
132
+ dir = text_embeddings[-2] - text_embeddings[-1]
133
+ u, s, v = torch.pca_lowrank(dir.transpose(-1, -2), q=1, center=True)
134
+ text_embeddings[-1] = text_embeddings[-1] + kwds.get("dir") * v
135
+ print(u.shape)
136
+ print(v.shape)
137
+
138
+ # define initial latents
139
+ latents_shape = (batch_size, self.unet.in_channels, height//8, width//8)
140
+ if latents is None:
141
+ latents = torch.randn(latents_shape, device=DEVICE)
142
+ else:
143
+ assert latents.shape == latents_shape, f"The shape of input latent tensor {latents.shape} should equal to predefined one."
144
+
145
+ # unconditional embedding for classifier free guidance
146
+ if guidance_scale > 1.:
147
+ max_length = text_input.input_ids.shape[-1]
148
+ if neg_prompt:
149
+ uc_text = neg_prompt
150
+ else:
151
+ uc_text = ""
152
+ # uc_text = "ugly, tiling, poorly drawn hands, poorly drawn feet, body out of frame, cut off, low contrast, underexposed, distorted face"
153
+ unconditional_input = self.tokenizer(
154
+ [uc_text] * batch_size,
155
+ padding="max_length",
156
+ max_length=77,
157
+ return_tensors="pt"
158
+ )
159
+ # unconditional_input.input_ids = unconditional_input.input_ids[:, 1:]
160
+ unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0]
161
+ text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0)
162
+
163
+ print("latents shape: ", latents.shape)
164
+ # iterative sampling
165
+ self.scheduler.set_timesteps(num_inference_steps)
166
+ # print("Valid timesteps: ", reversed(self.scheduler.timesteps))
167
+ latents_list = [latents]
168
+ pred_x0_list = [latents]
169
+ for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="DDIM Sampler")):
170
+ if ref_intermediate_latents is not None:
171
+ # note that the batch_size >= 2
172
+ latents_ref = ref_intermediate_latents[-1 - i]
173
+ _, latents_cur = latents.chunk(2)
174
+ latents = torch.cat([latents_ref, latents_cur])
175
+
176
+ if guidance_scale > 1.:
177
+ model_inputs = torch.cat([latents] * 2)
178
+ else:
179
+ model_inputs = latents
180
+ if unconditioning is not None and isinstance(unconditioning, list):
181
+ _, text_embeddings = text_embeddings.chunk(2)
182
+ text_embeddings = torch.cat([unconditioning[i].expand(*text_embeddings.shape), text_embeddings])
183
+ # predict tghe noise
184
+ noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample
185
+ if guidance_scale > 1.:
186
+ noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
187
+ noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
188
+ # compute the previous noise sample x_t -> x_t-1
189
+ latents, pred_x0 = self.step(noise_pred, t, latents)
190
+ latents_list.append(latents)
191
+ pred_x0_list.append(pred_x0)
192
+
193
+ image = self.latent2image(latents, return_type="pt")
194
+ if return_intermediates:
195
+ pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list]
196
+ latents_list = [self.latent2image(img, return_type="pt") for img in latents_list]
197
+ return image, pred_x0_list, latents_list
198
+ return image
199
+
200
+ @torch.no_grad()
201
+ def invert(
202
+ self,
203
+ image: torch.Tensor,
204
+ prompt,
205
+ num_inference_steps=50,
206
+ guidance_scale=7.5,
207
+ eta=0.0,
208
+ return_intermediates=False,
209
+ **kwds):
210
+ """
211
+ invert a real image into noise map with determinisc DDIM inversion
212
+ """
213
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
214
+ batch_size = image.shape[0]
215
+ if isinstance(prompt, list):
216
+ if batch_size == 1:
217
+ image = image.expand(len(prompt), -1, -1, -1)
218
+ elif isinstance(prompt, str):
219
+ if batch_size > 1:
220
+ prompt = [prompt] * batch_size
221
+
222
+ # text embeddings
223
+ text_input = self.tokenizer(
224
+ prompt,
225
+ padding="max_length",
226
+ max_length=77,
227
+ return_tensors="pt"
228
+ )
229
+ text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
230
+ print("input text embeddings :", text_embeddings.shape)
231
+ # define initial latents
232
+ latents = self.image2latent(image)
233
+ start_latents = latents
234
+ # print(latents)
235
+ # exit()
236
+ # unconditional embedding for classifier free guidance
237
+ if guidance_scale > 1.:
238
+ max_length = text_input.input_ids.shape[-1]
239
+ unconditional_input = self.tokenizer(
240
+ [""] * batch_size,
241
+ padding="max_length",
242
+ max_length=77,
243
+ return_tensors="pt"
244
+ )
245
+ unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0]
246
+ text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0)
247
+
248
+ print("latents shape: ", latents.shape)
249
+ # interative sampling
250
+ self.scheduler.set_timesteps(num_inference_steps)
251
+ print("Valid timesteps: ", reversed(self.scheduler.timesteps))
252
+ # print("attributes: ", self.scheduler.__dict__)
253
+ latents_list = [latents]
254
+ pred_x0_list = [latents]
255
+ for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
256
+ if guidance_scale > 1.:
257
+ model_inputs = torch.cat([latents] * 2)
258
+ else:
259
+ model_inputs = latents
260
+
261
+ # predict the noise
262
+ noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample
263
+ if guidance_scale > 1.:
264
+ noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
265
+ noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
266
+ # compute the previous noise sample x_t-1 -> x_t
267
+ latents, pred_x0 = self.next_step(noise_pred, t, latents)
268
+ latents_list.append(latents)
269
+ pred_x0_list.append(pred_x0)
270
+
271
+ if return_intermediates:
272
+ # return the intermediate laters during inversion
273
+ # pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list]
274
+ return latents, latents_list
275
+ return latents, start_latents
masactrl/masactrl.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+
7
+ from einops import rearrange
8
+
9
+ from .masactrl_utils import AttentionBase
10
+
11
+ from torchvision.utils import save_image
12
+
13
+
14
+ class MutualSelfAttentionControl(AttentionBase):
15
+ def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50):
16
+ """
17
+ Mutual self-attention control for Stable-Diffusion model
18
+ Args:
19
+ start_step: the step to start mutual self-attention control
20
+ start_layer: the layer to start mutual self-attention control
21
+ layer_idx: list of the layers to apply mutual self-attention control
22
+ step_idx: list the steps to apply mutual self-attention control
23
+ total_steps: the total number of steps
24
+ """
25
+ super().__init__()
26
+ self.total_steps = total_steps
27
+ self.start_step = start_step
28
+ self.start_layer = start_layer
29
+ self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, 16))
30
+ self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
31
+ print("step_idx: ", self.step_idx)
32
+ print("layer_idx: ", self.layer_idx)
33
+
34
+ def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
35
+ b = q.shape[0] // num_heads
36
+ q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
37
+ k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
38
+ v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
39
+
40
+ sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
41
+ attn = sim.softmax(-1)
42
+ out = torch.einsum("h i j, h j d -> h i d", attn, v)
43
+ out = rearrange(out, "h (b n) d -> b n (h d)", b=b)
44
+ return out
45
+
46
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
47
+ """
48
+ Attention forward function
49
+ """
50
+ if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
51
+ return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
52
+
53
+ qu, qc = q.chunk(2)
54
+ ku, kc = k.chunk(2)
55
+ vu, vc = v.chunk(2)
56
+ attnu, attnc = attn.chunk(2)
57
+
58
+ out_u = self.attn_batch(qu, ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)
59
+ out_c = self.attn_batch(qc, kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)
60
+ out = torch.cat([out_u, out_c], dim=0)
61
+
62
+ return out
63
+
64
+
65
+ class MutualSelfAttentionControlMask(MutualSelfAttentionControl):
66
+ def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, mask_s=None, mask_t=None, mask_save_dir=None):
67
+ """
68
+ Maske-guided MasaCtrl to alleviate the problem of fore- and background confusion
69
+ Args:
70
+ start_step: the step to start mutual self-attention control
71
+ start_layer: the layer to start mutual self-attention control
72
+ layer_idx: list of the layers to apply mutual self-attention control
73
+ step_idx: list the steps to apply mutual self-attention control
74
+ total_steps: the total number of steps
75
+ mask_s: source mask with shape (h, w)
76
+ mask_t: target mask with same shape as source mask
77
+ """
78
+ super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps)
79
+ self.mask_s = mask_s # source mask with shape (h, w)
80
+ self.mask_t = mask_t # target mask with same shape as source mask
81
+ print("Using mask-guided MasaCtrl")
82
+ if mask_save_dir is not None:
83
+ os.makedirs(mask_save_dir, exist_ok=True)
84
+ save_image(self.mask_s.unsqueeze(0).unsqueeze(0), os.path.join(mask_save_dir, "mask_s.png"))
85
+ save_image(self.mask_t.unsqueeze(0).unsqueeze(0), os.path.join(mask_save_dir, "mask_t.png"))
86
+
87
+ def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
88
+ B = q.shape[0] // num_heads
89
+ H = W = int(np.sqrt(q.shape[1]))
90
+ q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
91
+ k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
92
+ v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
93
+
94
+ sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
95
+ if kwargs.get("is_mask_attn") and self.mask_s is not None:
96
+ print("masked attention")
97
+ mask = self.mask_s.unsqueeze(0).unsqueeze(0)
98
+ mask = F.interpolate(mask, (H, W)).flatten(0).unsqueeze(0)
99
+ mask = mask.flatten()
100
+ # background
101
+ sim_bg = sim + mask.masked_fill(mask == 1, torch.finfo(sim.dtype).min)
102
+ # object
103
+ sim_fg = sim + mask.masked_fill(mask == 0, torch.finfo(sim.dtype).min)
104
+ sim = torch.cat([sim_fg, sim_bg], dim=0)
105
+ attn = sim.softmax(-1)
106
+ if len(attn) == 2 * len(v):
107
+ v = torch.cat([v] * 2)
108
+ out = torch.einsum("h i j, h j d -> h i d", attn, v)
109
+ out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
110
+ return out
111
+
112
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
113
+ """
114
+ Attention forward function
115
+ """
116
+ if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
117
+ return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
118
+
119
+ B = q.shape[0] // num_heads // 2
120
+ H = W = int(np.sqrt(q.shape[1]))
121
+ qu, qc = q.chunk(2)
122
+ ku, kc = k.chunk(2)
123
+ vu, vc = v.chunk(2)
124
+ attnu, attnc = attn.chunk(2)
125
+
126
+ out_u_source = self.attn_batch(qu[:num_heads], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)
127
+ out_c_source = self.attn_batch(qc[:num_heads], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)
128
+
129
+ out_u_target = self.attn_batch(qu[-num_heads:], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, is_mask_attn=True, **kwargs)
130
+ out_c_target = self.attn_batch(qc[-num_heads:], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, is_mask_attn=True, **kwargs)
131
+
132
+ if self.mask_s is not None and self.mask_t is not None:
133
+ out_u_target_fg, out_u_target_bg = out_u_target.chunk(2, 0)
134
+ out_c_target_fg, out_c_target_bg = out_c_target.chunk(2, 0)
135
+
136
+ mask = F.interpolate(self.mask_t.unsqueeze(0).unsqueeze(0), (H, W))
137
+ mask = mask.reshape(-1, 1) # (hw, 1)
138
+ out_u_target = out_u_target_fg * mask + out_u_target_bg * (1 - mask)
139
+ out_c_target = out_c_target_fg * mask + out_c_target_bg * (1 - mask)
140
+
141
+ out = torch.cat([out_u_source, out_u_target, out_c_source, out_c_target], dim=0)
142
+ return out
143
+
144
+
145
+ class MutualSelfAttentionControlMaskAuto(MutualSelfAttentionControl):
146
+ def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, thres=0.1, ref_token_idx=[1], cur_token_idx=[1], mask_save_dir=None):
147
+ """
148
+ MasaCtrl with mask auto generation from cross-attention map
149
+ Args:
150
+ start_step: the step to start mutual self-attention control
151
+ start_layer: the layer to start mutual self-attention control
152
+ layer_idx: list of the layers to apply mutual self-attention control
153
+ step_idx: list the steps to apply mutual self-attention control
154
+ total_steps: the total number of steps
155
+ thres: the thereshold for mask thresholding
156
+ ref_token_idx: the token index list for cross-attention map aggregation
157
+ cur_token_idx: the token index list for cross-attention map aggregation
158
+ mask_save_dir: the path to save the mask image
159
+ """
160
+ super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps)
161
+ print("using MutualSelfAttentionControlMaskAuto")
162
+ self.thres = thres
163
+ self.ref_token_idx = ref_token_idx
164
+ self.cur_token_idx = cur_token_idx
165
+
166
+ self.self_attns = []
167
+ self.cross_attns = []
168
+
169
+ self.cross_attns_mask = None
170
+ self.self_attns_mask = None
171
+
172
+ self.mask_save_dir = mask_save_dir
173
+ if self.mask_save_dir is not None:
174
+ os.makedirs(self.mask_save_dir, exist_ok=True)
175
+
176
+ def after_step(self):
177
+ self.self_attns = []
178
+ self.cross_attns = []
179
+
180
+ def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
181
+ B = q.shape[0] // num_heads
182
+ H = W = int(np.sqrt(q.shape[1]))
183
+ q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
184
+ k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
185
+ v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
186
+
187
+ sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
188
+ if self.self_attns_mask is not None:
189
+ # binarize the mask
190
+ mask = self.self_attns_mask
191
+ thres = self.thres
192
+ mask[mask >= thres] = 1
193
+ mask[mask < thres] = 0
194
+ sim_fg = sim + mask.masked_fill(mask == 0, torch.finfo(sim.dtype).min)
195
+ sim_bg = sim + mask.masked_fill(mask == 1, torch.finfo(sim.dtype).min)
196
+ sim = torch.cat([sim_fg, sim_bg])
197
+
198
+ attn = sim.softmax(-1)
199
+
200
+ if len(attn) == 2 * len(v):
201
+ v = torch.cat([v] * 2)
202
+ out = torch.einsum("h i j, h j d -> h i d", attn, v)
203
+ out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
204
+ return out
205
+
206
+ def aggregate_cross_attn_map(self, idx):
207
+ attn_map = torch.stack(self.cross_attns, dim=1).mean(1) # (B, N, dim)
208
+ B = attn_map.shape[0]
209
+ res = int(np.sqrt(attn_map.shape[-2]))
210
+ attn_map = attn_map.reshape(-1, res, res, attn_map.shape[-1])
211
+ image = attn_map[..., idx]
212
+ if isinstance(idx, list):
213
+ image = image.sum(-1)
214
+ image_min = image.min(dim=1, keepdim=True)[0].min(dim=2, keepdim=True)[0]
215
+ image_max = image.max(dim=1, keepdim=True)[0].max(dim=2, keepdim=True)[0]
216
+ image = (image - image_min) / (image_max - image_min)
217
+ return image
218
+
219
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
220
+ """
221
+ Attention forward function
222
+ """
223
+ if is_cross:
224
+ # save cross attention map with res 16 * 16
225
+ if attn.shape[1] == 16 * 16:
226
+ self.cross_attns.append(attn.reshape(-1, num_heads, *attn.shape[-2:]).mean(1))
227
+
228
+ if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
229
+ return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
230
+
231
+ B = q.shape[0] // num_heads // 2
232
+ H = W = int(np.sqrt(q.shape[1]))
233
+ qu, qc = q.chunk(2)
234
+ ku, kc = k.chunk(2)
235
+ vu, vc = v.chunk(2)
236
+ attnu, attnc = attn.chunk(2)
237
+
238
+ out_u_source = self.attn_batch(qu[:num_heads], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)
239
+ out_c_source = self.attn_batch(qc[:num_heads], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)
240
+
241
+ if len(self.cross_attns) == 0:
242
+ self.self_attns_mask = None
243
+ out_u_target = self.attn_batch(qu[-num_heads:], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)
244
+ out_c_target = self.attn_batch(qc[-num_heads:], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)
245
+ else:
246
+ mask = self.aggregate_cross_attn_map(idx=self.ref_token_idx) # (2, H, W)
247
+ mask_source = mask[-2] # (H, W)
248
+ res = int(np.sqrt(q.shape[1]))
249
+ self.self_attns_mask = F.interpolate(mask_source.unsqueeze(0).unsqueeze(0), (res, res)).flatten()
250
+ if self.mask_save_dir is not None:
251
+ H = W = int(np.sqrt(self.self_attns_mask.shape[0]))
252
+ mask_image = self.self_attns_mask.reshape(H, W).unsqueeze(0)
253
+ save_image(mask_image, os.path.join(self.mask_save_dir, f"mask_s_{self.cur_step}_{self.cur_att_layer}.png"))
254
+ out_u_target = self.attn_batch(qu[-num_heads:], ku[:num_heads], vu[:num_heads], sim[:num_heads], attnu, is_cross, place_in_unet, num_heads, **kwargs)
255
+ out_c_target = self.attn_batch(qc[-num_heads:], kc[:num_heads], vc[:num_heads], sim[:num_heads], attnc, is_cross, place_in_unet, num_heads, **kwargs)
256
+
257
+ if self.self_attns_mask is not None:
258
+ mask = self.aggregate_cross_attn_map(idx=self.cur_token_idx) # (2, H, W)
259
+ mask_target = mask[-1] # (H, W)
260
+ res = int(np.sqrt(q.shape[1]))
261
+ spatial_mask = F.interpolate(mask_target.unsqueeze(0).unsqueeze(0), (res, res)).reshape(-1, 1)
262
+ if self.mask_save_dir is not None:
263
+ H = W = int(np.sqrt(spatial_mask.shape[0]))
264
+ mask_image = spatial_mask.reshape(H, W).unsqueeze(0)
265
+ save_image(mask_image, os.path.join(self.mask_save_dir, f"mask_t_{self.cur_step}_{self.cur_att_layer}.png"))
266
+ # binarize the mask
267
+ thres = self.thres
268
+ spatial_mask[spatial_mask >= thres] = 1
269
+ spatial_mask[spatial_mask < thres] = 0
270
+ out_u_target_fg, out_u_target_bg = out_u_target.chunk(2)
271
+ out_c_target_fg, out_c_target_bg = out_c_target.chunk(2)
272
+
273
+ out_u_target = out_u_target_fg * spatial_mask + out_u_target_bg * (1 - spatial_mask)
274
+ out_c_target = out_c_target_fg * spatial_mask + out_c_target_bg * (1 - spatial_mask)
275
+
276
+ # set self self-attention mask to None
277
+ self.self_attns_mask = None
278
+
279
+ out = torch.cat([out_u_source, out_u_target, out_c_source, out_c_target], dim=0)
280
+ return out
masactrl/masactrl_utils.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from typing import Optional, Union, Tuple, List, Callable, Dict
9
+
10
+ from torchvision.utils import save_image
11
+ from einops import rearrange, repeat
12
+
13
+
14
+ class AttentionBase:
15
+ def __init__(self):
16
+ self.cur_step = 0
17
+ self.num_att_layers = -1
18
+ self.cur_att_layer = 0
19
+
20
+ def after_step(self):
21
+ pass
22
+
23
+ def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
24
+ out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
25
+ self.cur_att_layer += 1
26
+ if self.cur_att_layer == self.num_att_layers:
27
+ self.cur_att_layer = 0
28
+ self.cur_step += 1
29
+ # after step
30
+ self.after_step()
31
+ return out
32
+
33
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
34
+ out = torch.einsum('b i j, b j d -> b i d', attn, v)
35
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
36
+ return out
37
+
38
+ def reset(self):
39
+ self.cur_step = 0
40
+ self.cur_att_layer = 0
41
+
42
+
43
+ class AttentionStore(AttentionBase):
44
+ def __init__(self, res=[32], min_step=0, max_step=1000):
45
+ super().__init__()
46
+ self.res = res
47
+ self.min_step = min_step
48
+ self.max_step = max_step
49
+ self.valid_steps = 0
50
+
51
+ self.self_attns = [] # store the all attns
52
+ self.cross_attns = []
53
+
54
+ self.self_attns_step = [] # store the attns in each step
55
+ self.cross_attns_step = []
56
+
57
+ def after_step(self):
58
+ if self.cur_step > self.min_step and self.cur_step < self.max_step:
59
+ self.valid_steps += 1
60
+ if len(self.self_attns) == 0:
61
+ self.self_attns = self.self_attns_step
62
+ self.cross_attns = self.cross_attns_step
63
+ else:
64
+ for i in range(len(self.self_attns)):
65
+ self.self_attns[i] += self.self_attns_step[i]
66
+ self.cross_attns[i] += self.cross_attns_step[i]
67
+ self.self_attns_step.clear()
68
+ self.cross_attns_step.clear()
69
+
70
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
71
+ if attn.shape[1] <= 64 ** 2: # avoid OOM
72
+ if is_cross:
73
+ self.cross_attns_step.append(attn)
74
+ else:
75
+ self.self_attns_step.append(attn)
76
+ return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
77
+
78
+
79
+ def regiter_attention_editor_diffusers(model, editor: AttentionBase):
80
+ """
81
+ Register a attention editor to Diffuser Pipeline, refer from [Prompt-to-Prompt]
82
+ """
83
+ def ca_forward(self, place_in_unet):
84
+ def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
85
+ """
86
+ The attention is similar to the original implementation of LDM CrossAttention class
87
+ except adding some modifications on the attention
88
+ """
89
+ if encoder_hidden_states is not None:
90
+ context = encoder_hidden_states
91
+ if attention_mask is not None:
92
+ mask = attention_mask
93
+
94
+ to_out = self.to_out
95
+ if isinstance(to_out, nn.modules.container.ModuleList):
96
+ to_out = self.to_out[0]
97
+ else:
98
+ to_out = self.to_out
99
+
100
+ h = self.heads
101
+ q = self.to_q(x)
102
+ is_cross = context is not None
103
+ context = context if is_cross else x
104
+ k = self.to_k(context)
105
+ v = self.to_v(context)
106
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
107
+
108
+ sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
109
+
110
+ if mask is not None:
111
+ mask = rearrange(mask, 'b ... -> b (...)')
112
+ max_neg_value = -torch.finfo(sim.dtype).max
113
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
114
+ mask = mask[:, None, :].repeat(h, 1, 1)
115
+ sim.masked_fill_(~mask, max_neg_value)
116
+
117
+ attn = sim.softmax(dim=-1)
118
+ # the only difference
119
+ out = editor(
120
+ q, k, v, sim, attn, is_cross, place_in_unet,
121
+ self.heads, scale=self.scale)
122
+
123
+ return to_out(out)
124
+
125
+ return forward
126
+
127
+ def register_editor(net, count, place_in_unet):
128
+ for name, subnet in net.named_children():
129
+ if net.__class__.__name__ == 'Attention': # spatial Transformer layer
130
+ net.forward = ca_forward(net, place_in_unet)
131
+ return count + 1
132
+ elif hasattr(net, 'children'):
133
+ count = register_editor(subnet, count, place_in_unet)
134
+ return count
135
+
136
+ cross_att_count = 0
137
+ for net_name, net in model.unet.named_children():
138
+ if "down" in net_name:
139
+ cross_att_count += register_editor(net, 0, "down")
140
+ elif "mid" in net_name:
141
+ cross_att_count += register_editor(net, 0, "mid")
142
+ elif "up" in net_name:
143
+ cross_att_count += register_editor(net, 0, "up")
144
+ editor.num_att_layers = cross_att_count
145
+
146
+
147
+ def regiter_attention_editor_ldm(model, editor: AttentionBase):
148
+ """
149
+ Register a attention editor to Stable Diffusion model, refer from [Prompt-to-Prompt]
150
+ """
151
+ def ca_forward(self, place_in_unet):
152
+ def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
153
+ """
154
+ The attention is similar to the original implementation of LDM CrossAttention class
155
+ except adding some modifications on the attention
156
+ """
157
+ if encoder_hidden_states is not None:
158
+ context = encoder_hidden_states
159
+ if attention_mask is not None:
160
+ mask = attention_mask
161
+
162
+ to_out = self.to_out
163
+ if isinstance(to_out, nn.modules.container.ModuleList):
164
+ to_out = self.to_out[0]
165
+ else:
166
+ to_out = self.to_out
167
+
168
+ h = self.heads
169
+ q = self.to_q(x)
170
+ is_cross = context is not None
171
+ context = context if is_cross else x
172
+ k = self.to_k(context)
173
+ v = self.to_v(context)
174
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
175
+
176
+ sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
177
+
178
+ if mask is not None:
179
+ mask = rearrange(mask, 'b ... -> b (...)')
180
+ max_neg_value = -torch.finfo(sim.dtype).max
181
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
182
+ mask = mask[:, None, :].repeat(h, 1, 1)
183
+ sim.masked_fill_(~mask, max_neg_value)
184
+
185
+ attn = sim.softmax(dim=-1)
186
+ # the only difference
187
+ out = editor(
188
+ q, k, v, sim, attn, is_cross, place_in_unet,
189
+ self.heads, scale=self.scale)
190
+
191
+ return to_out(out)
192
+
193
+ return forward
194
+
195
+ def register_editor(net, count, place_in_unet):
196
+ for name, subnet in net.named_children():
197
+ if net.__class__.__name__ == 'CrossAttention': # spatial Transformer layer
198
+ net.forward = ca_forward(net, place_in_unet)
199
+ return count + 1
200
+ elif hasattr(net, 'children'):
201
+ count = register_editor(subnet, count, place_in_unet)
202
+ return count
203
+
204
+ cross_att_count = 0
205
+ for net_name, net in model.model.diffusion_model.named_children():
206
+ if "input" in net_name:
207
+ cross_att_count += register_editor(net, 0, "input")
208
+ elif "middle" in net_name:
209
+ cross_att_count += register_editor(net, 0, "middle")
210
+ elif "output" in net_name:
211
+ cross_att_count += register_editor(net, 0, "output")
212
+ editor.num_att_layers = cross_att_count
playground.ipynb ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "### MasaCtrl: Tuning-free Mutual Self-Attention Control for Consistent Image Synthesis and Editing"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "import os\n",
17
+ "import torch\n",
18
+ "import torch.nn as nn\n",
19
+ "import torch.nn.functional as F\n",
20
+ "\n",
21
+ "import numpy as np\n",
22
+ "\n",
23
+ "from tqdm import tqdm\n",
24
+ "from einops import rearrange, repeat\n",
25
+ "from omegaconf import OmegaConf\n",
26
+ "\n",
27
+ "from diffusers import DDIMScheduler\n",
28
+ "\n",
29
+ "from masactrl.diffuser_utils import MasaCtrlPipeline\n",
30
+ "from masactrl.masactrl_utils import AttentionBase\n",
31
+ "from masactrl.masactrl_utils import regiter_attention_editor_diffusers\n",
32
+ "\n",
33
+ "from torchvision.utils import save_image\n",
34
+ "from torchvision.io import read_image\n",
35
+ "from pytorch_lightning import seed_everything\n",
36
+ "\n",
37
+ "torch.cuda.set_device(6) # set the GPU device"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "markdown",
42
+ "metadata": {},
43
+ "source": [
44
+ "#### Model Construction"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": null,
50
+ "metadata": {},
51
+ "outputs": [],
52
+ "source": [
53
+ "# Note that you may add your Hugging Face token to get access to the models\n",
54
+ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
55
+ "model_path = \"andite/anything-v4.0\"\n",
56
+ "# model_path = \"runwayml/stable-diffusion-v1-5\"\n",
57
+ "scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", clip_sample=False, set_alpha_to_one=False)\n",
58
+ "model = MasaCtrlPipeline.from_pretrained(model_path, scheduler=scheduler, cross_attention_kwargs={\"scale\": 0.5}).to(device)"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "markdown",
63
+ "metadata": {},
64
+ "source": [
65
+ "#### Consistent synthesis with MasaCtrl"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "execution_count": null,
71
+ "metadata": {},
72
+ "outputs": [],
73
+ "source": [
74
+ "from masactrl.masactrl import MutualSelfAttentionControl\n",
75
+ "\n",
76
+ "\n",
77
+ "seed = 42\n",
78
+ "seed_everything(seed)\n",
79
+ "\n",
80
+ "out_dir = \"./workdir/masactrl_exp/\"\n",
81
+ "os.makedirs(out_dir, exist_ok=True)\n",
82
+ "sample_count = len(os.listdir(out_dir))\n",
83
+ "out_dir = os.path.join(out_dir, f\"sample_{sample_count}\")\n",
84
+ "os.makedirs(out_dir, exist_ok=True)\n",
85
+ "\n",
86
+ "prompts = [\n",
87
+ " \"1boy, casual, outdoors, sitting\", # source prompt\n",
88
+ " \"1boy, casual, outdoors, standing\" # target prompt\n",
89
+ "]\n",
90
+ "\n",
91
+ "# initialize the noise map\n",
92
+ "start_code = torch.randn([1, 4, 64, 64], device=device)\n",
93
+ "start_code = start_code.expand(len(prompts), -1, -1, -1)\n",
94
+ "\n",
95
+ "# inference the synthesized image without MasaCtrl\n",
96
+ "editor = AttentionBase()\n",
97
+ "regiter_attention_editor_diffusers(model, editor)\n",
98
+ "image_ori = model(prompts, latents=start_code, guidance_scale=7.5)\n",
99
+ "\n",
100
+ "# inference the synthesized image with MasaCtrl\n",
101
+ "STEP = 4\n",
102
+ "LAYPER = 10\n",
103
+ "\n",
104
+ "# hijack the attention module\n",
105
+ "editor = MutualSelfAttentionControl(STEP, LAYPER)\n",
106
+ "regiter_attention_editor_diffusers(model, editor)\n",
107
+ "\n",
108
+ "# inference the synthesized image\n",
109
+ "image_masactrl = model(prompts, latents=start_code, guidance_scale=7.5)[-1:]\n",
110
+ "\n",
111
+ "# save the synthesized image\n",
112
+ "out_image = torch.cat([image_ori, image_masactrl], dim=0)\n",
113
+ "save_image(out_image, os.path.join(out_dir, f\"all_step{STEP}_layer{LAYPER}.png\"))\n",
114
+ "save_image(out_image[0], os.path.join(out_dir, f\"source_step{STEP}_layer{LAYPER}.png\"))\n",
115
+ "save_image(out_image[1], os.path.join(out_dir, f\"without_step{STEP}_layer{LAYPER}.png\"))\n",
116
+ "save_image(out_image[2], os.path.join(out_dir, f\"masactrl_step{STEP}_layer{LAYPER}.png\"))\n",
117
+ "\n",
118
+ "print(\"Syntheiszed images are saved in\", out_dir)"
119
+ ]
120
+ }
121
+ ],
122
+ "metadata": {
123
+ "kernelspec": {
124
+ "display_name": "Python 3.8.5 ('ldm')",
125
+ "language": "python",
126
+ "name": "python3"
127
+ },
128
+ "language_info": {
129
+ "codemirror_mode": {
130
+ "name": "ipython",
131
+ "version": 3
132
+ },
133
+ "file_extension": ".py",
134
+ "mimetype": "text/x-python",
135
+ "name": "python",
136
+ "nbconvert_exporter": "python",
137
+ "pygments_lexer": "ipython3",
138
+ "version": "3.8.5"
139
+ },
140
+ "orig_nbformat": 4,
141
+ "vscode": {
142
+ "interpreter": {
143
+ "hash": "587aa04bacead72c1ffd459abbe4c8140b72ba2b534b24165b36a2ede3d95042"
144
+ }
145
+ }
146
+ },
147
+ "nbformat": 4,
148
+ "nbformat_minor": 2
149
+ }
playground_real.ipynb ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "### MasaCtrl: Tuning-free Mutual Self-Attention Control for Consistent Image Synthesis and Editing"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "import os\n",
17
+ "import torch\n",
18
+ "import torch.nn as nn\n",
19
+ "import torch.nn.functional as F\n",
20
+ "\n",
21
+ "import numpy as np\n",
22
+ "\n",
23
+ "from tqdm import tqdm\n",
24
+ "from einops import rearrange, repeat\n",
25
+ "from omegaconf import OmegaConf\n",
26
+ "\n",
27
+ "from diffusers import DDIMScheduler\n",
28
+ "\n",
29
+ "from masactrl.diffuser_utils import MasaCtrlPipeline\n",
30
+ "from masactrl.masactrl_utils import AttentionBase\n",
31
+ "from masactrl.masactrl_utils import regiter_attention_editor_diffusers\n",
32
+ "\n",
33
+ "from torchvision.utils import save_image\n",
34
+ "from torchvision.io import read_image\n",
35
+ "from pytorch_lightning import seed_everything\n",
36
+ "\n",
37
+ "torch.cuda.set_device(6) # set the GPU device"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "markdown",
42
+ "metadata": {},
43
+ "source": [
44
+ "#### Model Construction"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": null,
50
+ "metadata": {},
51
+ "outputs": [],
52
+ "source": [
53
+ "# Note that you may add your Hugging Face token to get access to the models\n",
54
+ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
55
+ "# model_path = \"andite/anything-v4.0\"\n",
56
+ "model_path = \"CompVis/stable-diffusion-v1-4\"\n",
57
+ "# model_path = \"runwayml/stable-diffusion-v1-5\"\n",
58
+ "scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", clip_sample=False, set_alpha_to_one=False)\n",
59
+ "model = MasaCtrlPipeline.from_pretrained(model_path, scheduler=scheduler).to(device)"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "markdown",
64
+ "metadata": {},
65
+ "source": [
66
+ "#### Real editing with MasaCtrl"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": null,
72
+ "metadata": {},
73
+ "outputs": [],
74
+ "source": [
75
+ "from masactrl.masactrl import MutualSelfAttentionControl\n",
76
+ "from torchvision.io import read_image\n",
77
+ "\n",
78
+ "\n",
79
+ "def load_image(image_path, device):\n",
80
+ " image = read_image(image_path)\n",
81
+ " image = image[:3].unsqueeze_(0).float() / 127.5 - 1. # [-1, 1]\n",
82
+ " image = F.interpolate(image, (512, 512))\n",
83
+ " image = image.to(device)\n",
84
+ " return image\n",
85
+ "\n",
86
+ "\n",
87
+ "seed = 42\n",
88
+ "seed_everything(seed)\n",
89
+ "\n",
90
+ "out_dir = \"./workdir/masactrl_real_exp/\"\n",
91
+ "os.makedirs(out_dir, exist_ok=True)\n",
92
+ "sample_count = len(os.listdir(out_dir))\n",
93
+ "out_dir = os.path.join(out_dir, f\"sample_{sample_count}\")\n",
94
+ "os.makedirs(out_dir, exist_ok=True)\n",
95
+ "\n",
96
+ "# source image\n",
97
+ "SOURCE_IMAGE_PATH = \"./gradio_app/images/corgi.jpg\"\n",
98
+ "source_image = load_image(SOURCE_IMAGE_PATH, device)\n",
99
+ "\n",
100
+ "source_prompt = \"\"\n",
101
+ "target_prompt = \"a photo of a running corgi\"\n",
102
+ "prompts = [source_prompt, target_prompt]\n",
103
+ "\n",
104
+ "# invert the source image\n",
105
+ "start_code, latents_list = model.invert(source_image,\n",
106
+ " source_prompt,\n",
107
+ " guidance_scale=7.5,\n",
108
+ " num_inference_steps=50,\n",
109
+ " return_intermediates=True)\n",
110
+ "start_code = start_code.expand(len(prompts), -1, -1, -1)\n",
111
+ "\n",
112
+ "# results of direct synthesis\n",
113
+ "editor = AttentionBase()\n",
114
+ "regiter_attention_editor_diffusers(model, editor)\n",
115
+ "image_fixed = model([target_prompt],\n",
116
+ " latents=start_code[-1:],\n",
117
+ " num_inference_steps=50,\n",
118
+ " guidance_scale=7.5)\n",
119
+ "\n",
120
+ "# inference the synthesized image with MasaCtrl\n",
121
+ "STEP = 4\n",
122
+ "LAYPER = 10\n",
123
+ "\n",
124
+ "# hijack the attention module\n",
125
+ "editor = MutualSelfAttentionControl(STEP, LAYPER)\n",
126
+ "regiter_attention_editor_diffusers(model, editor)\n",
127
+ "\n",
128
+ "# inference the synthesized image\n",
129
+ "image_masactrl = model(prompts,\n",
130
+ " latents=start_code,\n",
131
+ " guidance_scale=7.5)\n",
132
+ "# Note: querying the inversion intermediate features latents_list\n",
133
+ "# may obtain better reconstruction and editing results\n",
134
+ "# image_masactrl = model(prompts,\n",
135
+ "# latents=start_code,\n",
136
+ "# guidance_scale=7.5,\n",
137
+ "# ref_intermediate_latents=latents_list)\n",
138
+ "\n",
139
+ "# save the synthesized image\n",
140
+ "out_image = torch.cat([source_image * 0.5 + 0.5,\n",
141
+ " image_masactrl[0:1],\n",
142
+ " image_fixed,\n",
143
+ " image_masactrl[-1:]], dim=0)\n",
144
+ "save_image(out_image, os.path.join(out_dir, f\"all_step{STEP}_layer{LAYPER}.png\"))\n",
145
+ "save_image(out_image[0], os.path.join(out_dir, f\"source_step{STEP}_layer{LAYPER}.png\"))\n",
146
+ "save_image(out_image[1], os.path.join(out_dir, f\"reconstructed_source_step{STEP}_layer{LAYPER}.png\"))\n",
147
+ "save_image(out_image[2], os.path.join(out_dir, f\"without_step{STEP}_layer{LAYPER}.png\"))\n",
148
+ "save_image(out_image[3], os.path.join(out_dir, f\"masactrl_step{STEP}_layer{LAYPER}.png\"))\n",
149
+ "\n",
150
+ "print(\"Syntheiszed images are saved in\", out_dir)"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": null,
156
+ "metadata": {},
157
+ "outputs": [],
158
+ "source": []
159
+ }
160
+ ],
161
+ "metadata": {
162
+ "kernelspec": {
163
+ "display_name": "Python 3.8.5 ('ldm')",
164
+ "language": "python",
165
+ "name": "python3"
166
+ },
167
+ "language_info": {
168
+ "codemirror_mode": {
169
+ "name": "ipython",
170
+ "version": 3
171
+ },
172
+ "file_extension": ".py",
173
+ "mimetype": "text/x-python",
174
+ "name": "python",
175
+ "nbconvert_exporter": "python",
176
+ "pygments_lexer": "ipython3",
177
+ "version": "3.8.5"
178
+ },
179
+ "orig_nbformat": 4,
180
+ "vscode": {
181
+ "interpreter": {
182
+ "hash": "587aa04bacead72c1ffd459abbe4c8140b72ba2b534b24165b36a2ede3d95042"
183
+ }
184
+ }
185
+ },
186
+ "nbformat": 4,
187
+ "nbformat_minor": 2
188
+ }
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ diffusers==0.15.0
2
+ transformers
3
+ opencv-python
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }