ChenWu98 committed
Commit 9fc2574 • 1 Parent(s): 334ea23

Update app.py

Files changed (1): app.py (+254, -17)
app.py CHANGED
@@ -4,10 +4,19 @@ import torch
 from PIL import Image
 import utils
 import streamlit as st
+import ptp_utils
+import seq_aligner
+import torch.nn.functional as nnf
+from typing import Optional, Union, Tuple, List, Callable, Dict
+import abc
+
+LOW_RESOURCE = False
+MAX_NUM_WORDS = 77
 
 is_colab = utils.is_google_colab()
 
-if True:
+
+if False:
     model_id_or_path = "CompVis/stable-diffusion-v1-4"
     scheduler = DDIMScheduler.from_config(model_id_or_path,
                                           use_auth_token=st.secrets["USER_TOKEN"],
@@ -15,21 +24,233 @@ if True:
     pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path,
                                                   use_auth_token=st.secrets["USER_TOKEN"],
                                                   scheduler=scheduler)
+    tokenizer = pipe.tokenizer
 
 if torch.cuda.is_available():
     pipe = pipe.to("cuda")
 
-device = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
+device_print = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+class LocalBlend:
+
+    def __call__(self, x_t, attention_store):
+        k = 1
+        maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
+        maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, MAX_NUM_WORDS) for item in maps]
+        maps = torch.cat(maps, dim=1)
+        maps = (maps * self.alpha_layers).sum(-1).mean(1)
+        mask = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
+        mask = nnf.interpolate(mask, size=(x_t.shape[2:]))
+        mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
+        mask = mask.gt(self.threshold)
+        mask = (mask[:1] + mask[1:]).float()
+        x_t = x_t[:1] + mask * (x_t - x_t[:1])
+        return x_t
+
+    def __init__(self, prompts: List[str], words: [List[List[str]]], threshold=.3):
+        alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS)
+        for i, (prompt, words_) in enumerate(zip(prompts, words)):
+            if type(words_) is str:
+                words_ = [words_]
+            for word in words_:
+                ind = ptp_utils.get_word_inds(prompt, word, tokenizer)
+                alpha_layers[i, :, :, :, :, ind] = 1
+        self.alpha_layers = alpha_layers.to(device)
+        self.threshold = threshold
+
+
+class AttentionControl(abc.ABC):
+
+    def step_callback(self, x_t):
+        return x_t
+
+    def between_steps(self):
+        return
+
+    @property
+    def num_uncond_att_layers(self):
+        return self.num_att_layers if LOW_RESOURCE else 0
+
+    @abc.abstractmethod
+    def forward(self, attn, is_cross: bool, place_in_unet: str):
+        raise NotImplementedError
+
+    def __call__(self, attn, is_cross: bool, place_in_unet: str):
+        if self.cur_att_layer >= self.num_uncond_att_layers:
+            if LOW_RESOURCE:
+                attn = self.forward(attn, is_cross, place_in_unet)
+            else:
+                h = attn.shape[0]
+                attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
+        self.cur_att_layer += 1
+        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
+            self.cur_att_layer = 0
+            self.cur_step += 1
+            self.between_steps()
+        return attn
+
+    def reset(self):
+        self.cur_step = 0
+        self.cur_att_layer = 0
+
+    def __init__(self):
+        self.cur_step = 0
+        self.num_att_layers = -1
+        self.cur_att_layer = 0
+
+
+class EmptyControl(AttentionControl):
+
+    def forward(self, attn, is_cross: bool, place_in_unet: str):
+        return attn
+
+
+class AttentionStore(AttentionControl):
+
+    @staticmethod
+    def get_empty_store():
+        return {"down_cross": [], "mid_cross": [], "up_cross": [],
+                "down_self": [], "mid_self": [], "up_self": []}
+
+    def forward(self, attn, is_cross: bool, place_in_unet: str):
+        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
+        if attn.shape[1] <= 32 ** 2:  # avoid memory overhead
+            self.step_store[key].append(attn)
+        return attn
+
+    def between_steps(self):
+        if len(self.attention_store) == 0:
+            self.attention_store = self.step_store
+        else:
+            for key in self.attention_store:
+                for i in range(len(self.attention_store[key])):
+                    self.attention_store[key][i] += self.step_store[key][i]
+        self.step_store = self.get_empty_store()
+
+    def get_average_attention(self):
+        average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store}
+        return average_attention
+
+    def reset(self):
+        super(AttentionStore, self).reset()
+        self.step_store = self.get_empty_store()
+        self.attention_store = {}
+
+    def __init__(self):
+        super(AttentionStore, self).__init__()
+        self.step_store = self.get_empty_store()
+        self.attention_store = {}
+
+
+class AttentionControlEdit(AttentionStore, abc.ABC):
+
+    def step_callback(self, x_t):
+        if self.local_blend is not None:
+            x_t = self.local_blend(x_t, self.attention_store)
+        return x_t
+
+    def replace_self_attention(self, attn_base, att_replace):
+        if att_replace.shape[2] <= 16 ** 2:
+            return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
+        else:
+            return att_replace
+
+    @abc.abstractmethod
+    def replace_cross_attention(self, attn_base, att_replace):
+        raise NotImplementedError
+
+    def forward(self, attn, is_cross: bool, place_in_unet: str):
+        super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
+        if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
+            h = attn.shape[0] // self.batch_size
+            attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
+            attn_base, attn_repalce = attn[0], attn[1:]
+            if is_cross:
+                alpha_words = self.cross_replace_alpha[self.cur_step]
+                attn_replace_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + (1 - alpha_words) * attn_repalce
+                attn[1:] = attn_replace_new
+            else:
+                attn[1:] = self.replace_self_attention(attn_base, attn_repalce)
+            attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
+        return attn
+
+    def __init__(self, prompts, num_steps: int,
+                 cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
+                 self_replace_steps: Union[float, Tuple[float, float]],
+                 local_blend: Optional[LocalBlend]):
+        super(AttentionControlEdit, self).__init__()
+        self.batch_size = len(prompts)
+        self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, tokenizer).to(device)
+        if type(self_replace_steps) is float:
+            self_replace_steps = 0, self_replace_steps
+        self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
+        self.local_blend = local_blend
+
+
+class AttentionReplace(AttentionControlEdit):
+
+    def replace_cross_attention(self, attn_base, att_replace):
+        return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)
+
+    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
+                 local_blend: Optional[LocalBlend] = None):
+        super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
+        self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device)
+
+
+class AttentionRefine(AttentionControlEdit):
+
+    def replace_cross_attention(self, attn_base, att_replace):
+        attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
+        attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
+        return attn_replace
+
+    def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
+                 local_blend: Optional[LocalBlend] = None):
+        super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend)
+        self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer)
+        self.mapper, alphas = self.mapper.to(device), alphas.to(device)
+        self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
+
+
+def get_equalizer(text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], Tuple[float, ...]]):
+    if type(word_select) is int or type(word_select) is str:
+        word_select = (word_select,)
+    equalizer = torch.ones(len(values), 77)
+    values = torch.tensor(values, dtype=torch.float32)
+    for word in word_select:
+        inds = ptp_utils.get_word_inds(text, word, tokenizer)
+        equalizer[:, inds] = values
+    return equalizer
 
 
 def inference(source_prompt, target_prompt, source_guidance_scale=1, guidance_scale=5, num_inference_steps=100,
-              width=512, height=512, seed=0, img=None, strength=0.7):
+              width=512, height=512, seed=0, img=None, strength=0.7,
+              cross_attention_control=None, cross_replace_steps=0.8, self_replace_steps=0.4):
 
     torch.manual_seed(seed)
 
     ratio = min(height / img.height, width / img.width)
     img = img.resize((int(img.width * ratio), int(img.height * ratio)))
 
+    # create the CAC controller.
+    if cross_attention_control == "replace":
+        controller = AttentionReplace([source_prompt, target_prompt],
+                                      num_inference_steps,
+                                      cross_replace_steps=cross_replace_steps,
+                                      self_replace_steps=self_replace_steps,
+                                      )
+        ptp_utils.register_attention_control(pipe, controller)
+    elif cross_attention_control == "refine":
+        controller = AttentionRefine([source_prompt, target_prompt],
+                                     num_inference_steps,
+                                     cross_replace_steps=cross_replace_steps,
+                                     self_replace_steps=self_replace_steps,
+                                     )
+        ptp_utils.register_attention_control(pipe, controller)
+
     results = pipe(prompt=target_prompt,
                    source_prompt=source_prompt,
                    init_image=img,
@@ -64,7 +285,7 @@ with gr.Blocks(css=css) as demo:
                <a href="https://huggingface.co/docs/diffusers/main/en/api/pipelines/cycle_diffusion">🧨 Pipeline doc</a> | <a href="https://arxiv.org/abs/2210.05559">📄 Paper link</a>
              </p>
              <p>You can skip the queue in the colab: <a href="https://colab.research.google.com/gist/ChenWu98/0aa4fe7be80f6b45d3d055df9f14353a/copy-of-fine-tuned-diffusion-gradio.ipynb"><img data-canonical-src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"></a></p>
-             Running on <b>{device}</b>{(" in a <b>Google Colab</b>." if is_colab else "")}
+             Running on <b>{device_print}</b>{(" in a <b>Google Colab</b>." if is_colab else "")}
              </p>
            </div>
        """
@@ -82,42 +303,58 @@ with gr.Blocks(css=css) as demo:
             # ).style(grid=[1], height="auto")
 
         with gr.Column(scale=45):
-            with gr.Tab("Options"):
+            with gr.Tab("Edit options"):
                 with gr.Group():
                     with gr.Row():
                         source_prompt = gr.Textbox(label="Source prompt", placeholder="Source prompt describes the input image")
+                        source_guidance_scale = gr.Slider(label="Source guidance scale", value=1, minimum=1, maximum=10)
                     with gr.Row():
                         target_prompt = gr.Textbox(label="Target prompt", placeholder="Target prompt describes the output image")
-
-                    with gr.Row():
-                        source_guidance_scale = gr.Slider(label="Source guidance scale", value=1, minimum=1, maximum=10)
                         guidance_scale = gr.Slider(label="Target guidance scale", value=5, minimum=1, maximum=10)
-
                    with gr.Row():
-                        num_inference_steps = gr.Slider(label="Number of inference steps", value=100, minimum=25, maximum=500, step=1)
                        strength = gr.Slider(label="Strength", value=0.7, minimum=0.5, maximum=1, step=0.01)
 
                     with gr.Row():
+                        generate = gr.Button(value="Edit")
+            with gr.Tab("Basic options"):
+                with gr.Group():
+                    with gr.Row():
+                        num_inference_steps = gr.Slider(label="Number of inference steps", value=100, minimum=25, maximum=500, step=1)
                         width = gr.Slider(label="Width", value=512, minimum=64, maximum=1024, step=8)
                         height = gr.Slider(label="Height", value=512, minimum=64, maximum=1024, step=8)
 
                     with gr.Row():
                         seed = gr.Slider(0, 2147483647, label='Seed', value=0, step=1)
+
+            with gr.Tab("CAC options"):
+                with gr.Group():
                     with gr.Row():
-                        generate = gr.Button(value="Edit")
+                        cross_attention_control = gr.Radio(label="CAC type", choices=["None", "Replace", "Refine"], value="None")
+                    with gr.Row():
+                        # If not "None", the following two parameters will be used.
+                        cross_replace_steps = gr.Slider(label="Cross replace steps", value=0.8, minimum=0.0, maximum=1, step=0.01)
+                        self_replace_steps = gr.Slider(label="Self replace steps", value=0.4, minimum=0.0, maximum=1, step=0.01)
 
     inputs = [source_prompt, target_prompt, source_guidance_scale, guidance_scale, num_inference_steps,
-              width, height, seed, img, strength]
+              width, height, seed, img, strength,
+              cross_attention_control, cross_replace_steps, self_replace_steps]
     generate.click(inference, inputs=inputs, outputs=image_out)
 
     ex = gr.Examples(
        [
-            ["An astronaut riding a horse", "An astronaut riding an elephant", 1, 2, 100, 0, "images/astronaut_horse.png", 0.8],
-            ["A black colored car.", "A blue colored car.", 1, 2, 100, 0, "images/black_car.png", 0.85],
-            ["An aerial view of autumn scene.", "An aerial view of winter scene.", 1, 5, 100, 0, "images/mausoleum.png", 0.9],
-            ["A green apple and a black backpack on the floor.", "A red apple and a black backpack on the floor.", 1, 7, 100, 0, "images/apple_bag.png", 0.9],
+            ["An astronaut riding a horse", "An astronaut riding an elephant", 1, 2, 100, "images/astronaut_horse.png", 0.8, "None", 0, 0],
+            ["An astronaut riding a horse", "An astronaut riding a elephant", 1, 2, 100, "images/astronaut_horse.png", 0.9, "Replace", 0.15, 0.10],
+            ["A black colored car.", "A blue colored car.", 1, 2, 100, "images/black_car.png", 0.85, "None", 0, 0],
+            ["A black colored car.", "A blue colored car.", 1, 5, 100, "images/black_car.png", 0.95, "Replace", 0.8, 0.4],
+            ["A black colored car.", "A red colored car.", 1, 5, 100, "images/black_car.png", 1, "Replace", 0.8, 0.4],
+            ["An aerial view of autumn scene.", "An aerial view of winter scene.", 1, 5, 100, "images/mausoleum.png", 0.9, "None", 0.0, 0.0],
+            ["An aerial view of autumn scene.", "An aerial view of winter scene.", 1, 5, 100, "images/mausoleum.png", 1, "Replace", 0.8, 0.4],
+            ["A green apple and a black backpack on the floor.", "A red apple and a black backpack on the floor.", 1, 7, 100, "images/apple_bag.png", 0.9, "None", 0.0, 0.0],
+            ["A green apple and a black backpack on the floor.", "A red apple and a black backpack on the floor.", 1, 7, 100, "images/apple_bag.png", 0.9, "Replace", 0.8, 0.4],
        ],
-        [source_prompt, target_prompt, source_guidance_scale, guidance_scale, num_inference_steps, seed, img, strength],
+        [source_prompt, target_prompt, source_guidance_scale, guidance_scale, num_inference_steps,
+         img, strength,
+         cross_attention_control, cross_replace_steps, self_replace_steps],
        image_out, inference, cache_examples=False)
 
    gr.Markdown('''
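
For reference, a minimal sketch (not part of the commit) of calling the updated inference() directly, outside the Gradio UI, with the new cross-attention-control arguments. The image path and parameter values mirror one of the examples above; "replace" is the string checked inside inference(), and the call assumes the module-level pipe and tokenizer have already been initialized.

from PIL import Image

# Example image shipped with the Space; any RGB image should work.
img = Image.open("images/black_car.png")

edited = inference(
    source_prompt="A black colored car.",
    target_prompt="A blue colored car.",
    source_guidance_scale=1,
    guidance_scale=5,
    num_inference_steps=100,
    width=512,
    height=512,
    seed=0,
    img=img,
    strength=0.95,
    cross_attention_control="replace",  # "replace" or "refine" enables CAC; None (default) skips it
    cross_replace_steps=0.8,
    self_replace_steps=0.4,
)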