6kplus committed on
Commit
948429b
1 Parent(s): 0a4f406

Upload 30 files

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ images/InfEdit.jpg filter=lfs diff=lfs merge=lfs -text
+ images/sam.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,6 @@
  ---
  title: InfEdit
- emoji: 🌍
- colorFrom: purple
- colorTo: purple
+ app_file: app_ead_instuct.py
  sdk: gradio
- sdk_version: 4.9.0
- app_file: app.py
- pinned: false
- license: cc-by-nc-sa-4.0
+ sdk_version: 4.7.1
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app_ead_instuct.py ADDED
@@ -0,0 +1,620 @@
1
+ from diffusers import LCMScheduler
2
+ from pipeline_ead import EditPipeline
3
+ import os
4
+ import gradio as gr
5
+ import torch
6
+ from PIL import Image
7
+ import torch.nn.functional as nnf
8
+ from typing import Optional, Union, Tuple, List, Callable, Dict
9
+ import abc
10
+ import ptp_utils
11
+ import utils
12
+ import numpy as np
13
+ import seq_aligner
14
+ import math
15
+
16
+ LOW_RESOURCE = False
17
+ MAX_NUM_WORDS = 77
18
+
19
+ is_colab = utils.is_google_colab()
20
+ colab_instruction = "" if is_colab else """
21
+ Colab Instruction"""
22
+
23
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
24
+ model_id_or_path = "SimianLuo/LCM_Dreamshaper_v7"
25
+ device_print = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
26
+ device = "cuda" if torch.cuda.is_available() else "cpu"
27
+
28
+ if is_colab:
29
+ scheduler = LCMScheduler.from_config(model_id_or_path, subfolder="scheduler")
30
+ pipe = EditPipeline.from_pretrained(model_id_or_path, scheduler=scheduler, torch_dtype=torch_dtype)
31
+ else:
32
+ # import streamlit as st
33
+ # scheduler = DDIMScheduler.from_config(model_id_or_path, use_auth_token=st.secrets["USER_TOKEN"], subfolder="scheduler")
34
+ # pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, use_auth_token=st.secrets["USER_TOKEN"], scheduler=scheduler, torch_dtype=torch_dtype)
35
+ scheduler = LCMScheduler.from_config(model_id_or_path, use_auth_token=os.environ.get("USER_TOKEN"), subfolder="scheduler")
36
+ pipe = EditPipeline.from_pretrained(model_id_or_path, use_auth_token=os.environ.get("USER_TOKEN"), scheduler=scheduler, torch_dtype=torch_dtype)
37
+
38
+ tokenizer = pipe.tokenizer
39
+ encoder = pipe.text_encoder
40
+
41
+ if torch.cuda.is_available():
42
+ pipe = pipe.to("cuda")
43
+
44
+
45
+ class LocalBlend:
46
+
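+ # get_mask: weight the per-token cross-attention maps by the selected word
+ # indices, restrict them to the prompt tokens (dropping the start/end tokens),
+ # average over the collected heads/layers, take the per-pixel maximum over the
+ # selected tokens, upsample to the latent resolution, normalise to [0, 1] and
+ # threshold to obtain a boolean blend mask.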
47
+ def get_mask(self,x_t,maps,word_idx, thresh, i):
48
+ # print(word_idx)
49
+ # print(maps.shape)
50
+ # for i in range(0,self.len):
51
+ # self.save_image(maps[:,:,:,:,i].mean(0,keepdim=True),i,"map")
52
+ maps = maps * word_idx.reshape(1,1,1,1,-1)
53
+ maps = (maps[:,:,:,:,1:self.len-1]).mean(0,keepdim=True)
54
+ # maps = maps.mean(0,keepdim=True)
55
+ maps = (maps).max(-1)[0]
56
+ # self.save_image(maps,i,"map")
57
+ maps = nnf.interpolate(maps, size=(x_t.shape[2:]))
58
+ # maps = maps.mean(1,keepdim=True)\
59
+ maps = maps / maps.max(2, keepdim=True)[0].max(3, keepdim=True)[0]
60
+ mask = maps > thresh
61
+ return mask
62
+
63
+
64
+ def save_image(self,mask,i, caption):
65
+ image = mask[0, 0, :, :]
66
+ image = 255 * image / image.max()
67
+ # print(image.shape)
68
+ image = image.unsqueeze(-1).expand(*image.shape, 3)
69
+ # print(image.shape)
70
+ image = image.cpu().numpy().astype(np.uint8)
71
+ image = np.array(Image.fromarray(image).resize((256, 256)))
72
+ if not os.path.exists(f"inter/{caption}"):
73
+ os.mkdir(f"inter/{caption}")
74
+ ptp_utils.save_images(image, f"inter/{caption}/{i}.jpg")
75
+
76
+
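+ # __call__: build blend masks from the cross-attention maps collected at the
+ # coarse U-Net resolutions. The target latent x_t is kept where the target
+ # blend words attend (mask_e), x_m is used elsewhere, and the source latent
+ # x_s is restored where the source blend words attend (mask_m); on the last
+ # step the overlap mask_me is filled from x_m.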
77
+ def __call__(self, i, x_s, x_t, x_m, attention_store, alpha_prod, temperature=0.15, use_xm=False):
78
+ maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
79
+ h,w = x_t.shape[2],x_t.shape[3]
80
+ h , w = ((h+1)//2+1)//2, ((w+1)//2+1)//2
81
+ # print(h,w)
82
+ # print(maps[0].shape)
83
+ maps = [item.reshape(2, -1, 1, h // int((h*w/item.shape[-2])**0.5), w // int((h*w/item.shape[-2])**0.5), MAX_NUM_WORDS) for item in maps]
84
+ maps = torch.cat(maps, dim=1)
85
+ maps_s = maps[0,:]
86
+ maps_m = maps[1,:]
87
+ thresh_e = temperature / alpha_prod ** (0.5)
88
+ if thresh_e < self.thresh_e:
89
+ thresh_e = self.thresh_e
90
+ thresh_m = self.thresh_m
91
+ mask_e = self.get_mask(x_t, maps_m, self.alpha_e, thresh_e, i)
92
+ mask_m = self.get_mask(x_t, maps_s, (self.alpha_m-self.alpha_me), thresh_m, i)
93
+ mask_me = self.get_mask(x_t, maps_m, self.alpha_me, self.thresh_e, i)
94
+ if self.save_inter:
95
+ self.save_image(mask_e,i,"mask_e")
96
+ self.save_image(mask_m,i,"mask_m")
97
+ self.save_image(mask_me,i,"mask_me")
98
+
99
+ if self.alpha_e.sum() == 0:
100
+ x_t_out = x_t
101
+ else:
102
+ x_t_out = torch.where(mask_e, x_t, x_m)
103
+ x_t_out = torch.where(mask_m, x_s, x_t_out)
104
+ if use_xm:
105
+ x_t_out = torch.where(mask_me, x_m, x_t_out)
106
+
107
+ return x_m, x_t_out
108
+
109
+ def __init__(self,thresh_e=0.3, thresh_m=0.3, save_inter = False):
110
+ self.thresh_e = thresh_e
111
+ self.thresh_m = thresh_m
112
+ self.save_inter = save_inter
113
+
114
+ def set_map(self, ms, alpha, alpha_e, alpha_m,len):
115
+ self.m = ms
116
+ self.alpha = alpha
117
+ self.alpha_e = alpha_e
118
+ self.alpha_m = alpha_m
119
+ alpha_me = alpha_e.to(torch.bool) & alpha_m.to(torch.bool)
120
+ self.alpha_me = alpha_me.to(torch.float)
121
+ self.len = len
122
+
123
+
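+ # AttentionControl: abstract hook installed by ptp_utils.register_attention_control
+ # on every attention layer of the U-Net. It counts layers and diffusion steps
+ # and, unless LOW_RESOURCE is set, forwards only the conditional half of the
+ # attention batch to the subclass.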
124
+ class AttentionControl(abc.ABC):
125
+
126
+ def step_callback(self, x_t):
127
+ return x_t
128
+
129
+ def between_steps(self):
130
+ return
131
+
132
+ @property
133
+ def num_uncond_att_layers(self):
134
+ return self.num_att_layers if LOW_RESOURCE else 0
135
+
136
+ @abc.abstractmethod
137
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
138
+ raise NotImplementedError
139
+
140
+ def __call__(self, attn, is_cross: bool, place_in_unet: str):
141
+ if self.cur_att_layer >= self.num_uncond_att_layers:
142
+ if LOW_RESOURCE:
143
+ attn = self.forward(attn, is_cross, place_in_unet)
144
+ else:
145
+ h = attn.shape[0]
146
+ attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
147
+ self.cur_att_layer += 1
148
+ if self.cur_att_layer == self.num_att_layers // 2 + self.num_uncond_att_layers:
149
+ self.cur_att_layer = 0
150
+ self.cur_step += 1
151
+ self.between_steps()
152
+ return attn
153
+
154
+ def reset(self):
155
+ self.cur_step = 0
156
+ self.cur_att_layer = 0
157
+
158
+ def __init__(self):
159
+ self.cur_step = 0
160
+ self.num_att_layers = -1
161
+ self.cur_att_layer = 0
162
+
163
+
164
+ class EmptyControl(AttentionControl):
165
+
166
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
167
+ return attn
168
+ def self_attn_forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
169
+ b = q.shape[0] // num_heads
170
+ out = torch.einsum("h i j, h j d -> h i d", attn, v)
171
+ return out
172
+
173
+
174
+ class AttentionStore(AttentionControl):
175
+
176
+ @staticmethod
177
+ def get_empty_store():
178
+ return {"down_cross": [], "mid_cross": [], "up_cross": [],
179
+ "down_self": [], "mid_self": [], "up_self": []}
180
+
181
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
182
+ key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
183
+ if attn.shape[1] <= 32 ** 2: # avoid memory overhead
184
+ self.step_store[key].append(attn)
185
+ return attn
186
+
187
+ def between_steps(self):
188
+ if len(self.attention_store) == 0:
189
+ self.attention_store = self.step_store
190
+ else:
191
+ for key in self.attention_store:
192
+ for i in range(len(self.attention_store[key])):
193
+ self.attention_store[key][i] += self.step_store[key][i]
194
+ self.step_store = self.get_empty_store()
195
+
196
+ def get_average_attention(self):
197
+ average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in self.attention_store}
198
+ return average_attention
199
+
200
+ def reset(self):
201
+ super(AttentionStore, self).reset()
202
+ self.step_store = self.get_empty_store()
203
+ self.attention_store = {}
204
+
205
+ def __init__(self):
206
+ super(AttentionStore, self).__init__()
207
+ self.step_store = self.get_empty_store()
208
+ self.attention_store = {}
209
+
210
+
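+ # AttentionControlEdit: base class for the attention-controlled edit. It
+ # stores cross-attention maps for LocalBlend, can overwrite the second
+ # branch's cross-attention with the source attention while cross_replace_steps
+ # is active, and rewires the self-attention q/k/v across the three branches
+ # (see self_attn_forward below).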
211
+ class AttentionControlEdit(AttentionStore, abc.ABC):
212
+
213
+ def step_callback(self,i, t, x_s, x_t, x_m, alpha_prod):
214
+ if (self.local_blend is not None) and (i>0):
215
+ use_xm = (self.cur_step+self.start_steps+1 == self.num_steps)
216
+ x_m, x_t = self.local_blend(i, x_s, x_t, x_m, self.attention_store, alpha_prod, use_xm=use_xm)
217
+ return x_m, x_t
218
+
219
+ def replace_self_attention(self, attn_base, att_replace):
220
+ if att_replace.shape[2] <= 16 ** 2:
221
+ return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
222
+ else:
223
+ return att_replace
224
+
225
+ @abc.abstractmethod
226
+ def replace_cross_attention(self, attn_base, att_replace):
227
+ raise NotImplementedError
228
+
229
+ def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
230
+ b = q.shape[0] // num_heads
231
+
232
+ sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
233
+ attn = sim.softmax(-1)
234
+ out = torch.einsum("h i j, h j d -> h i d", attn, v)
235
+ return out
236
+
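+ # self_attn_forward: during the first self_replace_steps fraction of steps all
+ # three branches share the source queries and keys; afterwards the third
+ # branch keeps the target queries but still attends to the source keys and
+ # values. The same rewiring is applied to both the unconditional and
+ # conditional halves when classifier-free guidance is used.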
237
+ def self_attn_forward(self, q, k, v, num_heads):
238
+ if q.shape[0]//num_heads == 3:
239
+ if (self.self_replace_steps <= ((self.cur_step+self.start_steps+1)*1.0 / self.num_steps) ):
240
+ q=torch.cat([q[:num_heads*2],q[num_heads:num_heads*2]])
241
+ k=torch.cat([k[:num_heads*2],k[:num_heads]])
242
+ v=torch.cat([v[:num_heads*2],v[:num_heads]])
243
+ else:
244
+ q=torch.cat([q[:num_heads],q[:num_heads],q[:num_heads]])
245
+ k=torch.cat([k[:num_heads],k[:num_heads],k[:num_heads]])
246
+ v=torch.cat([v[:num_heads*2],v[:num_heads]])
247
+ return q,k,v
248
+ else:
249
+ qu, qc = q.chunk(2)
250
+ ku, kc = k.chunk(2)
251
+ vu, vc = v.chunk(2)
252
+ if (self.self_replace_steps <= ((self.cur_step+self.start_steps+1)*1.0 / self.num_steps) ):
253
+ qu=torch.cat([qu[:num_heads*2],qu[num_heads:num_heads*2]])
254
+ qc=torch.cat([qc[:num_heads*2],qc[num_heads:num_heads*2]])
255
+ ku=torch.cat([ku[:num_heads*2],ku[:num_heads]])
256
+ kc=torch.cat([kc[:num_heads*2],kc[:num_heads]])
257
+ vu=torch.cat([vu[:num_heads*2],vu[:num_heads]])
258
+ vc=torch.cat([vc[:num_heads*2],vc[:num_heads]])
259
+ else:
260
+ qu=torch.cat([qu[:num_heads],qu[:num_heads],qu[:num_heads]])
261
+ qc=torch.cat([qc[:num_heads],qc[:num_heads],qc[:num_heads]])
262
+ ku=torch.cat([ku[:num_heads],ku[:num_heads],ku[:num_heads]])
263
+ kc=torch.cat([kc[:num_heads],kc[:num_heads],kc[:num_heads]])
264
+ vu=torch.cat([vu[:num_heads*2],vu[:num_heads]])
265
+ vc=torch.cat([vc[:num_heads*2],vc[:num_heads]])
266
+
267
+ return torch.cat([qu, qc], dim=0) ,torch.cat([ku, kc], dim=0), torch.cat([vu, vc], dim=0)
268
+
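+ # forward (cross-attention hook): the three attention batches are the base,
+ # replace and masa branches. While cross_replace_steps is active, the replace
+ # branch's cross-attention is overwritten with the base attention aligned to
+ # the target tokens; the aligned base and refined maps are accumulated in the
+ # AttentionStore for LocalBlend.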
269
+ def forward(self, attn, is_cross: bool, place_in_unet: str):
270
+ if is_cross :
271
+ h = attn.shape[0] // self.batch_size
272
+ attn = attn.reshape(self.batch_size,h, *attn.shape[1:])
273
+ attn_base, attn_repalce,attn_masa = attn[0], attn[1], attn[2]
274
+ attn_replace_new = self.replace_cross_attention(attn_masa, attn_repalce)
275
+ attn_base_store = self.replace_cross_attention(attn_base, attn_repalce)
276
+ if (self.cross_replace_steps >= ((self.cur_step+self.start_steps+1)*1.0 / self.num_steps) ):
277
+ attn[1] = attn_base_store
278
+ attn_store=torch.cat([attn_base_store,attn_replace_new])
279
+ attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
280
+ attn_store = attn_store.reshape(2 *h, *attn_store.shape[2:])
281
+ super(AttentionControlEdit, self).forward(attn_store, is_cross, place_in_unet)
282
+ return attn
283
+
284
+ def __init__(self, prompts, num_steps: int,start_steps: int,
285
+ cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
286
+ self_replace_steps: Union[float, Tuple[float, float]],
287
+ local_blend: Optional[LocalBlend]):
288
+ super(AttentionControlEdit, self).__init__()
289
+ self.batch_size = len(prompts)+1
290
+ self.self_replace_steps = self_replace_steps
291
+ self.cross_replace_steps = cross_replace_steps
292
+ self.num_steps=num_steps
293
+ self.start_steps=start_steps
294
+ self.local_blend = local_blend
295
+
296
+
297
+ class AttentionReplace(AttentionControlEdit):
298
+
299
+ def replace_cross_attention(self, attn_base, att_replace):
300
+ return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)
301
+
302
+ def __init__(self, prompts, num_steps: int, start_steps: int, cross_replace_steps: float, self_replace_steps: float,
303
+ local_blend: Optional[LocalBlend] = None):
304
+ super(AttentionReplace, self).__init__(prompts, num_steps, start_steps, cross_replace_steps, self_replace_steps, local_blend)
305
+ self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device).to(torch_dtype)
306
+
307
+
308
+ class AttentionRefine(AttentionControlEdit):
309
+
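+ # replace_cross_attention: gather the masa-branch attention at the token
+ # positions given by the alignment mapper and blend it with the target
+ # attention using the per-token alphas from seq_aligner.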
310
+ def replace_cross_attention(self, attn_masa, att_replace):
311
+ attn_masa_replace = attn_masa[:, :, self.mapper].squeeze()
312
+ attn_replace = attn_masa_replace * self.alphas + \
313
+ att_replace * (1 - self.alphas)
314
+ return attn_replace
315
+
316
+ def __init__(self, prompts, prompt_specifiers, num_steps: int,start_steps: int, cross_replace_steps: float, self_replace_steps: float,
317
+ local_blend: Optional[LocalBlend] = None):
318
+ super(AttentionRefine, self).__init__(prompts, num_steps,start_steps, cross_replace_steps, self_replace_steps, local_blend)
319
+ self.mapper, alphas, ms, alpha_e, alpha_m = seq_aligner.get_refinement_mapper(prompts, prompt_specifiers, tokenizer, encoder, device)
320
+ self.mapper, alphas, ms = self.mapper.to(device), alphas.to(device).to(torch_dtype), ms.to(device).to(torch_dtype)
321
+ self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
322
+ self.ms = ms.reshape(ms.shape[0], 1, 1, ms.shape[1])
323
+ ms = ms.to(device)
324
+ alpha_e = alpha_e.to(device)
325
+ alpha_m = alpha_m.to(device)
326
+ t_len = len(tokenizer(prompts[1])["input_ids"])
327
+ self.local_blend.set_map(ms,alphas,alpha_e,alpha_m,t_len)
328
+
329
+
330
+ def get_equalizer(text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float], Tuple[float, ...]]):
331
+ if type(word_select) is int or type(word_select) is str:
332
+ word_select = (word_select,)
333
+ equalizer = torch.ones(len(values), 77)
334
+ values = torch.tensor(values, dtype=torch_dtype)
335
+ for word in word_select:
336
+ inds = ptp_utils.get_word_inds(text, word, tokenizer)
337
+ equalizer[:, inds] = values
338
+ return equalizer
339
+
340
+
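+ # inference: Gradio entry point. Optionally asks GPT-4 for editing parameters
+ # (instruction-following tab), resizes the input image, builds the LocalBlend
+ # and AttentionRefine controllers, registers them on the pipeline's attention
+ # layers and runs EditPipeline with the source/target prompts.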
341
+ def inference(img, source_prompt, target_prompt,
342
+ local, mutual,
343
+ positive_prompt, negative_prompt,
344
+ guidance_s, guidance_t,
345
+ num_inference_steps,
346
+ width, height, seed, strength,
347
+ cross_replace_steps, self_replace_steps,
348
+ thresh_e, thresh_m, denoise, user_instruct="", api_key=""):
349
+ print(img)
350
+ if user_instruct != "" and api_key != "":
351
+ source_prompt, target_prompt, local, mutual, replace_steps, num_inference_steps = get_params(api_key, user_instruct)
352
+ cross_replace_steps = replace_steps
353
+ self_replace_steps = replace_steps
354
+
355
+ torch.manual_seed(seed)
356
+ ratio = min(height / img.height, width / img.width)
357
+ img = img.resize((int(img.width * ratio), int(img.height * ratio)))
358
+ if denoise is False:
359
+ strength = 1
360
+ num_denoise_num = math.trunc(num_inference_steps*strength)
361
+ num_start = num_inference_steps-num_denoise_num
362
+ # create the CAC controller.
363
+ local_blend = LocalBlend(thresh_e=thresh_e, thresh_m=thresh_m, save_inter=False)
364
+ controller = AttentionRefine([source_prompt, target_prompt],[[local, mutual]],
365
+ num_inference_steps,
366
+ num_start,
367
+ cross_replace_steps=cross_replace_steps,
368
+ self_replace_steps=self_replace_steps,
369
+ local_blend=local_blend
370
+ )
371
+ ptp_utils.register_attention_control(pipe, controller)
372
+
373
+ results = pipe(prompt=target_prompt,
374
+ source_prompt=source_prompt,
375
+ positive_prompt=positive_prompt,
376
+ negative_prompt=negative_prompt,
377
+ image=img,
378
+ num_inference_steps=num_inference_steps,
379
+ eta=1,
380
+ strength=strength,
381
+ guidance_scale=guidance_t,
382
+ source_guidance_scale=guidance_s,
383
+ denoise_model=denoise,
384
+ callback = controller.step_callback
385
+ )
386
+
387
+ return replace_nsfw_images(results)
388
+
389
+
390
+ def replace_nsfw_images(results):
391
+ for i in range(len(results.images)):
392
+ if results.nsfw_content_detected[i]:
393
+ results.images[i] = Image.open("nsfw.png")
394
+ return results.images[0]
395
+
396
+
397
+ css = """.cycle-diffusion-div div{display:inline-flex;align-items:center;gap:.8rem;font-size:1.75rem}.cycle-diffusion-div div h1{font-weight:900;margin-bottom:7px}.cycle-diffusion-div p{margin-bottom:10px;font-size:94%}.cycle-diffusion-div p a{text-decoration:underline}.tabs{margin-top:0;margin-bottom:0}#gallery{min-height:20rem}
398
+ """
399
+ intro = """
400
+ <div style="display: flex;align-items: center;justify-content: center">
401
+ <img src="https://sled-group.github.io/InfEdit/image_assets/InfEdit.png" width="80" style="display: inline-block">
402
+ <h1 style="margin-left: 12px;text-align: center;margin-bottom: 7px;display: inline-block">InfEdit</h1>
403
+ <h3 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Inversion-Free Image Editing
404
+ with Natural Language</h3>
405
+ </div>
406
+ """
407
+
408
+ param_bot_prompt = """
409
+ You are a helpful assistant named InfEdit that provides input parameters to the image editing model based on user instructions. You should respond in valid json format.
410
+
411
+ User:
412
+ ```
413
+ {image description and editing commands | example: 'The image shows an apple on the table and I want to change the apple to a banana.'}
414
+ ```
415
+
416
+ After receiving this, you will need to generate the appropriate params as input to the image editing models.
417
+
418
+ Assistant:
419
+ ```
420
+ {
421
+ "source_prompt": "{a string that describes the input image; it must include the thing the user wants to change | example: 'an apple on the table'}",
422
+ "target_prompt": "{a string that matches the source prompt but includes the thing the user wants to change it to | example: 'a banana on the table'}",
423
+ "target_sub": "{a special substring from the target prompt}",
424
+ "mutual_sub": "{a special mutual substring from the source/target prompt}",
425
+ "attention_control": {a number between 0 and 1},
426
+ "steps": {a number between 8 and 50}
427
+ }
428
+ ```
429
+
430
+ You need to fill in the "target_sub" and "mutual_sub" by the guideline below.
431
+
432
+ If the editing instruction is not about changing style or background:
433
+ - The "target_sub" should be a special substring from the target prompt that highlights what you want to edit, it should be as short as possible and should only be noun ("banana" instead of "a banana").
434
+ - The "mutual_sub" should be kept as an empty string.
435
+ P.S. When you want to remove something, it's always better to replace it with "empty", "nothing" or another appropriate word. For example, to remove an apple on the table, you can use "an apple on the table" and "nothing on the table" as your prompts, and use "nothing" as your target_sub.
436
+ P.S. Think carefully about what you want to modify; for example, when changing "short hair" to "long hair", your target_sub should be "hair" instead of "long".
437
+ P.S. When you are adding something, the target_sub should be the thing you want to add.
438
+
439
+ If it's about style editing:
440
+ - The "target_sub" should be kept as an empty string.
441
+ - The "mutual_sub" should be kept as an empty string.
442
+
443
+ If it's about background editing:
444
+ - The "target_sub" should be kept as an empty string.
445
+ - The "mutual_sub" should be a common substring from source/target prompt, and is the main object/character (noun) in the image. It should be as short as possible and only be noun ("banana" instead of "a banana", "man" instead of "running man").
446
+
447
+ A special case: if the edit changes an object's abstract attributes, like pose, view or shape, while keeping its semantic identity, like a dog to a running dog:
448
+ - The "target_sub" should be a special substring from the target prompt that highlights what you want to edit; it should be as short as possible and should be a bare noun ("dog" instead of "a running dog").
449
+ - The "mutual_sub" should be the same as target_sub, because we want to "edit the dog but also keep it the same dog".
450
+
451
+
452
+ You need to choose a specific value of "attention_control" by the guideline below.
453
+ A larger value of "attention_control" means more consistency between the source image and the output.
454
+
455
+ - If the editing is on the feature level, like color or material, and you want to preserve the characteristics of the original object as much as possible, choose a large value. (Example: for color editing you can choose 1, and for material you can choose 0.9.)
456
+ - If the editing is on the object level, like editing a "cat" to a "dog" or a "horse" to a "zebra", and you want them to stay similar, choose a relatively large value, say 0.7.
457
+ - If the editing changes the style but should keep the spatial features, choose a relatively large value, say 0.7.
458
+ - If the editing needs to change something's shape, like editing an "apple" to a "banana", a "flower" to a "knife", "short" hair to "long" hair, or "round" to "square", which have very different shapes, choose a relatively small value, say 0.3.
459
+ - If the editing tries to change spatial information, like the pose, choose a relatively small value, say 0.3.
460
+ - If the editing should not consider consistency with the input image, like adding something new, removing something, or changing the background, you can directly use 0.
461
+
462
+
463
+ You need to choose a specific value of "steps" by the guideline below.
463
+ More steps mean that the editing effect is more pronounced.
464
+ - If the editing is very easy, like changing something to something with very similar features, you can choose 8 steps.
465
+ - In most cases, you can choose 15 steps.
466
+ - For style editing and removal tasks, you can choose a larger value, like 25 steps.
467
+ - If you feel the task is extremely difficult (like some styles, or removing very tiny objects), you can directly use 50 steps.
469
+ """
470
+ def get_params(api_key, user_instruct):
471
+ from openai import OpenAI
472
+ client = OpenAI(api_key=api_key)
473
+ print("user_instruct", user_instruct)
474
+ response = client.chat.completions.create(
475
+ model="gpt-4-1106-preview",
476
+ messages=[
477
+ {"role": "system", "content": param_bot_prompt},
478
+ {"role": "user", "content": user_instruct}
479
+ ],
480
+ response_format={ "type": "json_object" },
481
+ )
482
+ param_dict = response.choices[0].message.content
483
+ print("param_dict", param_dict)
484
+ import json
485
+ param_dict = json.loads(param_dict)
486
+ return param_dict['source_prompt'], param_dict['target_prompt'], param_dict['target_sub'], param_dict['mutual_sub'], param_dict['attention_control'], param_dict['steps']
487
+ with gr.Blocks(css=css) as demo:
488
+ gr.HTML(intro)
489
+ with gr.Accordion("README", open=False):
490
+ gr.HTML(
491
+ """
492
+ <p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block">
493
+ <a href="https://sled-group.github.io/InfEdit/" target="_blank">project page</a> | <a href="https://arxiv.org" target="_blank">paper</a>| <a href="https://github.com/sled-group/InfEdit/tree/website" target="_blank">handbook</a>
494
+ </p>
495
+
496
+ We are now hosting this demo on an A4000 GPU with 16 GiB of memory.
497
+ """
498
+ )
499
+ with gr.Row():
500
+
501
+ with gr.Column(scale=55):
502
+ with gr.Group():
503
+
504
+ img = gr.Image(label="Input image", height=512, type="pil")
505
+
506
+ image_out = gr.Image(label="Output image", height=512)
507
+ # gallery = gr.Gallery(
508
+ # label="Generated images", show_label=False, elem_id="gallery"
509
+ # ).style(grid=[1], height="auto")
510
+
511
+ with gr.Column(scale=45):
512
+
513
+ with gr.Tab("UAC options"):
514
+ with gr.Group():
515
+ with gr.Row():
516
+ source_prompt = gr.Textbox(label="Source prompt", placeholder="Source prompt describes the input image")
517
+ with gr.Row():
518
+ guidance_s = gr.Slider(label="Source guidance scale", value=1, minimum=1, maximum=10)
519
+ positive_prompt = gr.Textbox(label="Positive prompt", placeholder="")
520
+ with gr.Row():
521
+ target_prompt = gr.Textbox(label="Target prompt", placeholder="Target prompt describes the output image")
522
+ with gr.Row():
523
+ guidance_t = gr.Slider(label="Target guidance scale", value=2, minimum=1, maximum=10)
524
+ negative_prompt = gr.Textbox(label="Negative prompt", placeholder="")
525
+ with gr.Row():
526
+ local = gr.Textbox(label="Target blend", placeholder="")
527
+ thresh_e = gr.Slider(label="Target blend thresh", value=0.6, minimum=0, maximum=1)
528
+ with gr.Row():
529
+ mutual = gr.Textbox(label="Source blend", placeholder="")
530
+ thresh_m = gr.Slider(label="Source blend thresh", value=0.6, minimum=0, maximum=1)
531
+ with gr.Row():
532
+ cross_replace_steps = gr.Slider(label="Cross attn control schedule", value=0.7, minimum=0.0, maximum=1, step=0.01)
533
+ self_replace_steps = gr.Slider(label="Self attn control schedule", value=0.3, minimum=0.0, maximum=1, step=0.01)
534
+ with gr.Row():
535
+ denoise = gr.Checkbox(label='Denoising Mode', value=False)
536
+ strength = gr.Slider(label="Strength", value=0.7, minimum=0, maximum=1, step=0.01, visible=False)
537
+ denoise.change(fn=lambda value: gr.update(visible=value), inputs=denoise, outputs=strength)
538
+ with gr.Row():
539
+ generate1 = gr.Button(value="Run")
540
+
541
+ with gr.Tab("Advanced options"):
542
+ with gr.Group():
543
+ with gr.Row():
544
+ num_inference_steps = gr.Slider(label="Inference steps", value=15, minimum=1, maximum=50, step=1)
545
+ width = gr.Slider(label="Width", value=512, minimum=512, maximum=1024, step=8)
546
+ height = gr.Slider(label="Height", value=512, minimum=512, maximum=1024, step=8)
547
+ with gr.Row():
548
+ seed = gr.Slider(0, 2147483647, label='Seed', value=0, step=1)
549
+ with gr.Row():
550
+ generate3 = gr.Button(value="Run")
551
+
552
+ with gr.Tab("Instruction following (+GPT4)"):
553
+ guide_str = """Describe the image you uploaded and tell me how you want to edit it."""
554
+ with gr.Group():
555
+ api_key = gr.Textbox(label="YOUR OPENAI API KEY", placeholder="sk-xxx", lines = 1, type="password")
556
+ user_instruct = gr.Textbox(label=guide_str, placeholder="The image shows an apple on the table and I want to change the apple to a banana.", lines = 3)
557
+ # source_prompt, target_prompt, local, mutual = get_params(api_key, user_instruct)
558
+ with gr.Row():
559
+ generate4 = gr.Button(value="Run")
560
+
561
+ inputs1 = [img, source_prompt, target_prompt,
562
+ local, mutual,
563
+ positive_prompt, negative_prompt,
564
+ guidance_s, guidance_t,
565
+ num_inference_steps,
566
+ width, height, seed, strength,
567
+ cross_replace_steps, self_replace_steps,
568
+ thresh_e, thresh_m, denoise]
569
+ inputs4 =[img, source_prompt, target_prompt,
570
+ local, mutual,
571
+ positive_prompt, negative_prompt,
572
+ guidance_s, guidance_t,
573
+ num_inference_steps,
574
+ width, height, seed, strength,
575
+ cross_replace_steps, self_replace_steps,
576
+ thresh_e, thresh_m, denoise, user_instruct, api_key]
577
+ generate1.click(inference, inputs=inputs1, outputs=image_out)
578
+ generate3.click(inference, inputs=inputs1, outputs=image_out)
579
+ generate4.click(inference, inputs=inputs4, outputs=image_out)
580
+
581
+ ex = gr.Examples(
582
+ [
583
+ ["images/corgi.jpg","corgi","cat","cat","","","",1,2,15,512,512,0,1,0.7,0.7,0.6,0.6,False],
584
+ ["images/muffin.png","muffin","chihuahua","chihuahua","","","",1,2,15,512,512,0,1,0.65,0.6,0.4,0.7,False],
585
+ ["images/InfEdit.jpg","an anime girl holding a pad","an anime girl holding a book","book","girl ","","",1,2,15,512,512,0,1,0.8,0.8,0.6,0.6,False],
586
+ ["images/summer.jpg","a photo of summer scene","A photo of winter scene","","","","",1,2,15,512,512,0,1,1,1,0.6,0.7,False],
587
+ ["images/bear.jpg","A bear sitting on the ground","A bear standing on the ground","bear","","","",1,1.5,15,512,512,0,1,0.3,0.3,0.5,0.7,False],
588
+ ["images/james.jpg","a man playing basketball","a man playing soccer","soccer","man ","","",1,2,15,512,512,0,1,0,0,0.5,0.4,False],
589
+ ["images/osu.jfif","A football with OSU logo","A football with Umich logo","logo","","","",1,2,15,512,512,0,1,0.5,0,0.6,0.7,False],
590
+ ["images/groundhog.png","A anime groundhog head","A anime ferret head","head","","","",1,2,15,512,512,0,1,0.5,0.5,0.6,0.7,False],
591
+ ["images/miku.png","A anime girl with green hair and green eyes and shirt","A anime girl with red hair and red eyes and shirt","red hair and red eyes","shirt","","",1,2,15,512,512,0,1,1,1,0.2,0.8,False],
592
+ ["images/droplet.png","a blue droplet emoji with a smiling face with yellow dot","a red fire emoji with an angry face with yellow dot","","yellow dot","","",1,2,15,512,512,0,1,0.7,0.7,0.6,0.7,False],
593
+ ["images/moyu.png","an emoji holding a sign and a fish","an emoji holding a sign and a shark","shark","sign","","",1,2,15,512,512,0,1,0.7,0.7,0.5,0.7,False],
594
+ ["images/214000000000.jpg","a painting of a waterfall in the mountains","a painting of a waterfall and angels in the mountains","angels","","","",1,2,15,512,512,0,1,0,0,0.5,0.5,False],
595
+ ["images/311000000002.jpg","a lion in a suit sitting at a table with a laptop","a lion in a suit sitting at a table with nothing","nothing","","","",1,2,15,512,512,0,1,0,0,0.5,0.5,False],
596
+ ["images/genshin.png","anime girl, with blue logo","anime boy with golden hair named Link, from The Legend of Zelda, with legend of zelda logo","anime boy","","","",1,2,50,512,512,0,1,0.65,0.65,0.5,0.5,False],
597
+ ["images/angry.jpg","a man with bounding boxes at the door","a man with angry birds at the door","angry birds","a man","","",1,2,15,512,512,0,1,0.3,0.1,0.45,0.4,False],
598
+ ["images/Doom_Slayer.jpg","doom slayer from game doom","master chief from game halo","","","","",1,2,15,512,512,0,1,0.6,0.8,0.7,0.7,False],
599
+ ["images/Elon_Musk.webp","Elon Musk in front of a car","Mark Iv iron man suit in front of a car","Mark Iv iron man suit","car","","",1,2,15,512,512,0,1,0.5,0.3,0.6,0.7,False],
600
+ ["images/dragon.jpg","a mascot dragon","pixel art, a mascot dragon","","","","",1,2,25,512,512,0,1,0.7,0.7,0.6,0.6,False],
601
+ ["images/frieren.jpg","a anime girl with long white hair holding a bottle","a anime girl with long white hair holding a smartphone","smartphone","","","",1,2,15,512,512,0,1,0.7,0.7,0.7,0.7,False],
602
+ ["images/sam.png","a man with an openai logo","a man with a twitter logo","a twitter logo","a man","","",1,2,15,512,512,0,0.8,0,0,0.3,0.6,True],
603
+
604
+
605
+ ],
606
+ [img, source_prompt, target_prompt,
607
+ local, mutual,
608
+ positive_prompt, negative_prompt,
609
+ guidance_s, guidance_t,
610
+ num_inference_steps,
611
+ width, height, seed, strength,
612
+ cross_replace_steps, self_replace_steps,
613
+ thresh_e, thresh_m, denoise],
614
+ image_out, inference, cache_examples=True,examples_per_page=20)
615
+ # if not is_colab:
616
+ # demo.queue(concurrency_count=1)
617
+
618
+ # demo.launch(debug=False, share=False,server_name="0.0.0.0",server_port = 80)
619
+ demo.launch(debug=False, share=False)
620
+
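For reference, the first example row wired into the Gradio demo corresponds to the direct call sketched below. This is only a minimal sketch: it assumes the code from this commit is running in the same Python process (the module-level `demo.launch()` would otherwise block on import), that the bundled `images/corgi.jpg` from this commit is available, and the output filename is arbitrary.

```python
from PIL import Image

# Mirrors the first example row: source "corgi" -> target "cat", 15 LCM steps.
img = Image.open("images/corgi.jpg")
edited = inference(
    img, "corgi", "cat",      # image, source prompt, target prompt
    "cat", "",                # target blend word, source (mutual) blend word
    "", "",                   # positive prompt, negative prompt
    1, 2,                     # source / target guidance scales
    15,                       # num_inference_steps
    512, 512, 0, 1,           # width, height, seed, strength
    0.7, 0.7,                 # cross / self attention control schedules
    0.6, 0.6,                 # target / source blend thresholds
    False,                    # denoising mode off
)
edited.save("corgi_to_cat.png")
```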
images/214000000000.jpg ADDED
images/311000000002.jpg ADDED
images/Doom_Slayer.jpg ADDED
images/Elon_Musk.webp ADDED
images/InfEdit.jpg ADDED

Git LFS Details

  • SHA256: ff5d2c81b8a5fe77a95385ecf79356e9a1d204f2c5837e42d928ee4e255c4abc
  • Pointer size: 132 Bytes
  • Size of remote file: 1.31 MB
images/angry.jpg ADDED
images/bear.jpg ADDED
images/computer.png ADDED
images/corgi.jpg ADDED
images/dragon.jpg ADDED
images/droplet.png ADDED
images/frieren.jpg ADDED
images/genshin.png ADDED
images/groundhog.png ADDED
images/james.jpg ADDED
images/miku.png ADDED
images/moyu.png ADDED
images/muffin.png ADDED
images/osu.jfif ADDED
Binary file (4.38 kB).
 
images/sam.png ADDED

Git LFS Details

  • SHA256: e149ce2aba40d726168884c53ebb41cfb33b4b188f349fdb94fa07d1da28b74f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.23 MB
images/summer.jpg ADDED
nsfw.png ADDED
pipeline_ead.py ADDED
@@ -0,0 +1,707 @@
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+
4
+ import numpy as np
5
+ import PIL
6
+ import torch
7
+ from packaging import version
8
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
9
+
10
+ from diffusers.configuration_utils import FrozenDict
11
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
12
+ from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
13
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
14
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
15
+ from diffusers.schedulers import LCMScheduler
16
+ from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
17
+ from diffusers.utils.torch_utils import randn_tensor
18
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
19
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
20
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
21
+
22
+
23
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
24
+
25
+
26
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
27
+ def preprocess(image):
28
+ deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
29
+ deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
30
+ if isinstance(image, torch.Tensor):
31
+ return image
32
+ elif isinstance(image, PIL.Image.Image):
33
+ image = [image]
34
+
35
+ if isinstance(image[0], PIL.Image.Image):
36
+ w, h = image[0].size
37
+ w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
38
+
39
+ image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
40
+ image = np.concatenate(image, axis=0)
41
+ image = np.array(image).astype(np.float32) / 255.0
42
+ image = image.transpose(0, 3, 1, 2)
43
+ image = 2.0 * image - 1.0
44
+ image = torch.from_numpy(image)
45
+ elif isinstance(image[0], torch.Tensor):
46
+ image = torch.cat(image, dim=0)
47
+ return image
48
+
49
+
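+ # ddcm_sampler: joint sampling step for the source and target latents. The
+ # consistent noise e_c is recovered from the source latent x_s and the clean
+ # latents x_0, the target x0-prediction is corrected by the difference of the
+ # two noise predictions (e_t - e_s), and both branches are stepped to the
+ # previous timestep with the same direction and stochastic noise, so the
+ # source trajectory stays anchored to the clean latent x_0.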
50
+ def ddcm_sampler(scheduler, x_s, x_t, timestep, e_s, e_t, x_0, noise, eta, to_next=True):
51
+ if scheduler.num_inference_steps is None:
52
+ raise ValueError(
53
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
54
+ )
55
+
56
+ if scheduler.step_index is None:
57
+ scheduler._init_step_index(timestep)
58
+
59
+ prev_step_index = scheduler.step_index + 1
60
+ if prev_step_index < len(scheduler.timesteps):
61
+ prev_timestep = scheduler.timesteps[prev_step_index]
62
+ else:
63
+ prev_timestep = timestep
64
+
65
+ alpha_prod_t = scheduler.alphas_cumprod[timestep]
66
+ alpha_prod_t_prev = (
67
+ scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.final_alpha_cumprod
68
+ )
69
+ beta_prod_t = 1 - alpha_prod_t
70
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
71
+ variance = beta_prod_t_prev
72
+ std_dev_t = eta * variance
73
+ noise = std_dev_t ** (0.5) * noise
74
+
75
+ e_c = (x_s - alpha_prod_t ** (0.5) * x_0) / (1 - alpha_prod_t) ** (0.5)
76
+
77
+ pred_x0 = x_0 + ((x_t - x_s) - beta_prod_t ** (0.5) * (e_t - e_s)) / alpha_prod_t ** (0.5)
78
+ eps = (e_t - e_s) + e_c
79
+ dir_xt = (beta_prod_t_prev - std_dev_t) ** (0.5) * eps
80
+
81
+ # Noise is not used for one-step sampling.
82
+ if len(scheduler.timesteps) > 1:
83
+ prev_xt = alpha_prod_t_prev ** (0.5) * pred_x0 + dir_xt + noise
84
+ prev_xs = alpha_prod_t_prev ** (0.5) * x_0 + dir_xt + noise
85
+ else:
86
+ prev_xt = pred_x0
87
+ prev_xs = x_0
88
+
89
+ if to_next:
90
+ scheduler._step_index += 1
91
+ return prev_xs, prev_xt, pred_x0
92
+
93
+
94
+ class EditPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
95
+ model_cpu_offload_seq = "text_encoder->unet->vae"
96
+ _optional_components = ["safety_checker", "feature_extractor"]
97
+
98
+ def __init__(
99
+ self,
100
+ vae: AutoencoderKL,
101
+ text_encoder: CLIPTextModel,
102
+ tokenizer: CLIPTokenizer,
103
+ unet: UNet2DConditionModel,
104
+ scheduler: LCMScheduler,
105
+ safety_checker: StableDiffusionSafetyChecker,
106
+ feature_extractor: CLIPImageProcessor,
107
+ requires_safety_checker: bool = True,
108
+ ):
109
+ super().__init__()
110
+
111
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
112
+ deprecation_message = (
113
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
114
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
115
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
116
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
117
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
118
+ " file"
119
+ )
120
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
121
+ new_config = dict(scheduler.config)
122
+ new_config["steps_offset"] = 1
123
+ scheduler._internal_dict = FrozenDict(new_config)
124
+
125
+ if safety_checker is None and requires_safety_checker:
126
+ logger.warning(
127
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
128
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
129
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
130
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
131
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
132
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
133
+ )
134
+
135
+ if safety_checker is not None and feature_extractor is None:
136
+ raise ValueError(
137
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
138
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
139
+ )
140
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
141
+ version.parse(unet.config._diffusers_version).base_version
142
+ ) < version.parse("0.9.0.dev0")
143
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
144
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
145
+ deprecation_message = (
146
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
147
+ " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
148
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
149
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
150
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
151
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
152
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
153
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
154
+ " the `unet/config.json` file"
155
+ )
156
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
157
+ new_config = dict(unet.config)
158
+ new_config["sample_size"] = 64
159
+ unet._internal_dict = FrozenDict(new_config)
160
+
161
+ self.register_modules(
162
+ vae=vae,
163
+ text_encoder=text_encoder,
164
+ tokenizer=tokenizer,
165
+ unet=unet,
166
+ scheduler=scheduler,
167
+ safety_checker=safety_checker,
168
+ feature_extractor=feature_extractor,
169
+ )
170
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
171
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
172
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
173
+
174
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
175
+ def _encode_prompt(
176
+ self,
177
+ prompt,
178
+ device,
179
+ num_images_per_prompt,
180
+ do_classifier_free_guidance,
181
+ negative_prompt=None,
182
+ prompt_embeds: Optional[torch.FloatTensor] = None,
183
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
184
+ lora_scale: Optional[float] = None,
185
+ ):
186
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
187
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
188
+
189
+ prompt_embeds_tuple = self.encode_prompt(
190
+ prompt=prompt,
191
+ device=device,
192
+ num_images_per_prompt=num_images_per_prompt,
193
+ do_classifier_free_guidance=do_classifier_free_guidance,
194
+ negative_prompt=negative_prompt,
195
+ prompt_embeds=prompt_embeds,
196
+ negative_prompt_embeds=negative_prompt_embeds,
197
+ lora_scale=lora_scale,
198
+ )
199
+
200
+ # concatenate for backwards compatibility
201
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
202
+
203
+ return prompt_embeds
204
+
205
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
206
+ def encode_prompt(
207
+ self,
208
+ prompt,
209
+ device,
210
+ num_images_per_prompt,
211
+ do_classifier_free_guidance,
212
+ negative_prompt=None,
213
+ prompt_embeds: Optional[torch.FloatTensor] = None,
214
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
215
+ lora_scale: Optional[float] = None,
216
+ ):
217
+ # set lora scale so that monkey patched LoRA
218
+ # function of text encoder can correctly access it
219
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
220
+ self._lora_scale = lora_scale
221
+
222
+ # dynamically adjust the LoRA scale
223
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
224
+
225
+ if prompt is not None and isinstance(prompt, str):
226
+ batch_size = 1
227
+ elif prompt is not None and isinstance(prompt, list):
228
+ batch_size = len(prompt)
229
+ else:
230
+ batch_size = prompt_embeds.shape[0]
231
+
232
+ if prompt_embeds is None:
233
+ # textual inversion: process multi-vector tokens if necessary
234
+ if isinstance(self, TextualInversionLoaderMixin):
235
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
236
+
237
+ text_inputs = self.tokenizer(
238
+ prompt,
239
+ padding="max_length",
240
+ max_length=self.tokenizer.model_max_length,
241
+ truncation=True,
242
+ return_tensors="pt",
243
+ )
244
+ text_input_ids = text_inputs.input_ids
245
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
246
+
247
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
248
+ text_input_ids, untruncated_ids
249
+ ):
250
+ removed_text = self.tokenizer.batch_decode(
251
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
252
+ )
253
+ logger.warning(
254
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
255
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
256
+ )
257
+
258
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
259
+ attention_mask = text_inputs.attention_mask.to(device)
260
+ else:
261
+ attention_mask = None
262
+
263
+ prompt_embeds = self.text_encoder(
264
+ text_input_ids.to(device),
265
+ attention_mask=attention_mask,
266
+ )
267
+ prompt_embeds = prompt_embeds[0]
268
+
269
+ if self.text_encoder is not None:
270
+ prompt_embeds_dtype = self.text_encoder.dtype
271
+ elif self.unet is not None:
272
+ prompt_embeds_dtype = self.unet.dtype
273
+ else:
274
+ prompt_embeds_dtype = prompt_embeds.dtype
275
+
276
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
277
+
278
+ bs_embed, seq_len, _ = prompt_embeds.shape
279
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
280
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
281
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
282
+
283
+ # get unconditional embeddings for classifier free guidance
284
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
285
+ uncond_tokens: List[str]
286
+ if negative_prompt is None:
287
+ uncond_tokens = [""] * batch_size
288
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
289
+ raise TypeError(
290
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
291
+ f" {type(prompt)}."
292
+ )
293
+ elif isinstance(negative_prompt, str):
294
+ uncond_tokens = [negative_prompt]
295
+ elif batch_size != len(negative_prompt):
296
+ raise ValueError(
297
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
298
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
299
+ " the batch size of `prompt`."
300
+ )
301
+ else:
302
+ uncond_tokens = negative_prompt
303
+
304
+ # textual inversion: process multi-vector tokens if necessary
305
+ if isinstance(self, TextualInversionLoaderMixin):
306
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
307
+
308
+ max_length = prompt_embeds.shape[1]
309
+ uncond_input = self.tokenizer(
310
+ uncond_tokens,
311
+ padding="max_length",
312
+ max_length=max_length,
313
+ truncation=True,
314
+ return_tensors="pt",
315
+ )
316
+
317
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
318
+ attention_mask = uncond_input.attention_mask.to(device)
319
+ else:
320
+ attention_mask = None
321
+
322
+ negative_prompt_embeds = self.text_encoder(
323
+ uncond_input.input_ids.to(device),
324
+ attention_mask=attention_mask,
325
+ )
326
+ negative_prompt_embeds = negative_prompt_embeds[0]
327
+
328
+ if do_classifier_free_guidance:
329
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
330
+ seq_len = negative_prompt_embeds.shape[1]
331
+
332
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
333
+
334
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
335
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
336
+
337
+ return prompt_embeds, negative_prompt_embeds
338
+
339
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.check_inputs
340
+ def check_inputs(
341
+ self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
342
+ ):
343
+ if strength < 0 or strength > 1:
344
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
345
+
346
+ if (callback_steps is None) or (
347
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
348
+ ):
349
+ raise ValueError(
350
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
351
+ f" {type(callback_steps)}."
352
+ )
353
+
354
+ if prompt is not None and prompt_embeds is not None:
355
+ raise ValueError(
356
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
357
+ " only forward one of the two."
358
+ )
359
+ elif prompt is None and prompt_embeds is None:
360
+ raise ValueError(
361
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
362
+ )
363
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
364
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
365
+
366
+ if negative_prompt is not None and negative_prompt_embeds is not None:
367
+ raise ValueError(
368
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
369
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
370
+ )
371
+
372
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
373
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
374
+ raise ValueError(
375
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
376
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
377
+ f" {negative_prompt_embeds.shape}."
378
+ )
379
+
380
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
381
+ def prepare_extra_step_kwargs(self, generator, eta):
382
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
383
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
384
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
385
+ # and should be between [0, 1]
386
+
387
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
388
+ extra_step_kwargs = {}
389
+ if accepts_eta:
390
+ extra_step_kwargs["eta"] = eta
391
+
392
+ # check if the scheduler accepts generator
393
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
394
+ if accepts_generator:
395
+ extra_step_kwargs["generator"] = generator
396
+ return extra_step_kwargs
397
+
398
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
399
+ def run_safety_checker(self, image, device, dtype):
400
+ if self.safety_checker is None:
401
+ has_nsfw_concept = None
402
+ else:
403
+ if torch.is_tensor(image):
404
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
405
+ else:
406
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
407
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
408
+ image, has_nsfw_concept = self.safety_checker(
409
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
410
+ )
411
+ return image, has_nsfw_concept
412
+
413
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
414
+ def decode_latents(self, latents):
415
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
416
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
417
+
418
+ latents = 1 / self.vae.config.scaling_factor * latents
419
+ image = self.vae.decode(latents, return_dict=False)[0]
420
+ image = (image / 2 + 0.5).clamp(0, 1)
421
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
422
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
423
+ return image
424
+
425
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
426
+ def get_timesteps(self, num_inference_steps, strength, device):
427
+ # get the original timestep using init_timestep
428
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
429
+
430
+ t_start = max(num_inference_steps - init_timestep, 0)
431
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
432
+
433
+ return timesteps, num_inference_steps - t_start
434
+
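+ # prepare_latents: encode the input image with the VAE (unless latents are
+ # passed directly), scale them, and either add scheduler noise at the chosen
+ # start timestep (denoise_model=True) or start from pure noise; the
+ # unperturbed latents are also returned as clean_latents.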
435
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, denoise_model, generator=None):
436
+ image = image.to(device=device, dtype=dtype)
437
+
438
+ batch_size = image.shape[0]
439
+
440
+ if image.shape[1] == 4:
441
+ init_latents = image
442
+
443
+ else:
444
+ if isinstance(generator, list) and len(generator) != batch_size:
445
+ raise ValueError(
446
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
447
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
448
+ )
449
+
450
+ if isinstance(generator, list):
451
+ init_latents = [
452
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
453
+ ]
454
+ init_latents = torch.cat(init_latents, dim=0)
455
+ else:
456
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
457
+
458
+ init_latents = self.vae.config.scaling_factor * init_latents
459
+
460
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
461
+ # expand init_latents for batch_size
462
+ deprecation_message = (
463
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
464
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
465
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
466
+ " your script to pass as many initial images as text prompts to suppress this warning."
467
+ )
468
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
469
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
470
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt * num_images_per_prompt, dim=0)
471
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
472
+ raise ValueError(
473
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
474
+ )
475
+ else:
476
+ init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0)
477
+
478
+ # add noise to latents using the timestep
479
+ shape = init_latents.shape
480
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
481
+
482
+ # get latents
483
+ clean_latents = init_latents
484
+ if denoise_model:
485
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
486
+ latents = init_latents
487
+ else:
488
+ latents = noise
489
+
490
+ return latents, clean_latents
491
+
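`denoise_model` controls whether the loop starts from the noised source latents (via `scheduler.add_noise`) or from pure noise, while `clean_latents` always keeps the unnoised VAE encoding. Below is a minimal sketch of the forward-noising relation that `add_noise` implements, assuming the standard DDPM formulation with `alpha_prod_t` taken from the scheduler's `alphas_cumprod`; it is a stand-in, not the scheduler's actual code.

```python
import torch

def add_noise_sketch(clean_latents: torch.Tensor, noise: torch.Tensor,
                     alpha_prod_t: torch.Tensor) -> torch.Tensor:
    # noisy = sqrt(alpha_bar_t) * clean + sqrt(1 - alpha_bar_t) * noise
    return alpha_prod_t.sqrt() * clean_latents + (1.0 - alpha_prod_t).sqrt() * noise
```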
492
+ @torch.no_grad()
493
+ def __call__(
494
+ self,
495
+ prompt: Union[str, List[str]],
496
+ source_prompt: Union[str, List[str]],
497
+ negative_prompt: Union[str, List[str]]=None,
498
+ positive_prompt: Union[str, List[str]]=None,
499
+ image: PipelineImageInput = None,
500
+ strength: float = 0.8,
501
+ num_inference_steps: Optional[int] = 50,
502
+ original_inference_steps: Optional[int] = 50,
503
+ guidance_scale: Optional[float] = 7.5,
504
+ source_guidance_scale: Optional[float] = 1,
505
+ num_images_per_prompt: Optional[int] = 1,
506
+ eta: Optional[float] = 1.0,
507
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
508
+ prompt_embeds: Optional[torch.FloatTensor] = None,
509
+ output_type: Optional[str] = "pil",
510
+ return_dict: bool = True,
511
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
512
+ callback_steps: int = 1,
513
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
514
+ denoise_model: Optional[bool] = True,
515
+ ):
516
+ # 1. Check inputs
517
+ self.check_inputs(prompt, strength, callback_steps)
518
+
519
+ # 2. Define call parameters
520
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
521
+ device = self._execution_device
522
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier-free guidance.
525
+ do_classifier_free_guidance = guidance_scale > 1.0
526
+
527
+ # 3. Encode input prompt
528
+ text_encoder_lora_scale = (
529
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
530
+ )
531
+ prompt_embeds_tuple = self.encode_prompt(
532
+ prompt,
533
+ device,
534
+ num_images_per_prompt,
535
+ do_classifier_free_guidance,
536
+ negative_prompt=negative_prompt,
537
+ prompt_embeds=prompt_embeds,
538
+ lora_scale=text_encoder_lora_scale,
539
+ )
540
+ source_prompt_embeds_tuple = self.encode_prompt(
541
+ source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, positive_prompt, None
542
+ )
543
+ if prompt_embeds_tuple[1] is not None:
544
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
545
+ else:
546
+ prompt_embeds = prompt_embeds_tuple[0]
547
+ if source_prompt_embeds_tuple[1] is not None:
548
+ source_prompt_embeds = torch.cat([source_prompt_embeds_tuple[1], source_prompt_embeds_tuple[0]])
549
+ else:
550
+ source_prompt_embeds = source_prompt_embeds_tuple[0]
551
+
552
+ # 4. Preprocess image
553
+ image = self.image_processor.preprocess(image)
554
+
555
+ # 5. Prepare timesteps
556
+ self.scheduler.set_timesteps(
557
+ num_inference_steps=num_inference_steps,
558
+ device=device,
559
+ original_inference_steps=original_inference_steps)
560
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
561
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
562
+
563
+ # 6. Prepare latent variables
564
+ latents, clean_latents = self.prepare_latents(
565
+ image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, denoise_model, generator
566
+ )
567
+ source_latents = latents
568
+ mutual_latents = latents
569
+
570
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
571
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
572
+ generator = extra_step_kwargs.pop("generator", None)
573
+
574
+ # 8. Denoising loop
575
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
576
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
577
+ for i, t in enumerate(timesteps):
578
+ # expand the latents if we are doing classifier free guidance
579
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
580
+ source_latent_model_input = (
581
+ torch.cat([source_latents] * 2) if do_classifier_free_guidance else source_latents
582
+ )
583
+ mutual_latent_model_input = (
584
+ torch.cat([mutual_latents] * 2) if do_classifier_free_guidance else mutual_latents
585
+ )
586
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
587
+ source_latent_model_input = self.scheduler.scale_model_input(source_latent_model_input, t)
588
+ mutual_latent_model_input = self.scheduler.scale_model_input(mutual_latent_model_input, t)
589
+
590
+ # predict the noise residual
591
+ if do_classifier_free_guidance:
592
+ concat_latent_model_input = torch.stack(
593
+ [
594
+ source_latent_model_input[0],
595
+ latent_model_input[0],
596
+ mutual_latent_model_input[0],
597
+ source_latent_model_input[1],
598
+ latent_model_input[1],
599
+ mutual_latent_model_input[1],
600
+ ],
601
+ dim=0,
602
+ )
603
+ concat_prompt_embeds = torch.stack(
604
+ [
605
+ source_prompt_embeds[0],
606
+ prompt_embeds[0],
607
+ source_prompt_embeds[0],
608
+ source_prompt_embeds[1],
609
+ prompt_embeds[1],
610
+ source_prompt_embeds[1],
611
+ ],
612
+ dim=0,
613
+ )
614
+ else:
615
+ concat_latent_model_input = torch.cat(
616
+ [
617
+ source_latent_model_input,
618
+ latent_model_input,
619
+ mutual_latent_model_input,
620
+ ],
621
+ dim=0,
622
+ )
623
+ concat_prompt_embeds = torch.cat(
624
+ [
625
+ source_prompt_embeds,
626
+ prompt_embeds,
627
+ source_prompt_embeds,
628
+ ],
629
+ dim=0,
630
+ )
631
+
632
+ concat_noise_pred = self.unet(
633
+ concat_latent_model_input,
634
+ t,
635
+ cross_attention_kwargs=cross_attention_kwargs,
636
+ encoder_hidden_states=concat_prompt_embeds,
637
+ ).sample
638
+
639
+ # perform guidance
640
+ if do_classifier_free_guidance:
641
+ (
642
+ source_noise_pred_uncond,
643
+ noise_pred_uncond,
644
+ mutual_noise_pred_uncond,
645
+ source_noise_pred_text,
646
+ noise_pred_text,
647
+ mutual_noise_pred_text
648
+ ) = concat_noise_pred.chunk(6, dim=0)
649
+
650
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
651
+ source_noise_pred = source_noise_pred_uncond + source_guidance_scale * (
652
+ source_noise_pred_text - source_noise_pred_uncond
653
+ )
654
+ mutual_noise_pred = mutual_noise_pred_uncond + source_guidance_scale * (
655
+ mutual_noise_pred_text - mutual_noise_pred_uncond
656
+ )
657
+
658
+ else:
659
+ (source_noise_pred, noise_pred, mutual_noise_pred) = concat_noise_pred.chunk(3, dim=0)
660
+
661
+ noise = torch.randn(
662
+ latents.shape, dtype=latents.dtype, device=latents.device, generator=generator
663
+ )
664
+
665
+ _, latents, pred_x0 = ddcm_sampler(
666
+ self.scheduler, source_latents,
667
+ latents, t,
668
+ source_noise_pred, noise_pred,
669
+ clean_latents, noise=noise,
670
+ eta=eta, to_next=False,
671
+ **extra_step_kwargs
672
+ )
673
+
674
+ source_latents, mutual_latents, pred_xm = ddcm_sampler(
675
+ self.scheduler, source_latents,
676
+ mutual_latents, t,
677
+ source_noise_pred, mutual_noise_pred,
678
+ clean_latents, noise=noise,
679
+ eta=eta, **extra_step_kwargs
680
+ )
681
+
682
+ # call the callback, if provided
683
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
684
+ progress_bar.update()
685
+ if callback is not None and i % callback_steps == 0:
686
+ alpha_prod_t = self.scheduler.alphas_cumprod[t]
687
+ mutual_latents, latents = callback(i, t, source_latents, latents, mutual_latents, alpha_prod_t)
688
+
689
+ # 9. Post-processing
690
+ if not output_type == "latent":
691
+ image = self.vae.decode(pred_x0 / self.vae.config.scaling_factor, return_dict=False)[0]
692
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
693
+ else:
694
+ image = pred_x0
695
+ has_nsfw_concept = None
696
+
697
+ if has_nsfw_concept is None:
698
+ do_denormalize = [True] * image.shape[0]
699
+ else:
700
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
701
+
702
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
703
+
704
+ if not return_dict:
705
+ return (image, has_nsfw_concept)
706
+
707
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
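A hedged usage sketch of the `__call__` signature above. It assumes `pipe` is an already constructed `EditPipeline`; the prompts and file paths are placeholders, and the attention controller from `ptp_utils.py` (below) would normally be registered before calling the pipeline.

```python
# Minimal usage sketch (placeholder prompts and paths, not a verified recipe).
from PIL import Image

result = pipe(
    prompt="a watercolor painting of a cat",       # target prompt (placeholder)
    source_prompt="a photo of a cat",              # source prompt (placeholder)
    image=Image.open("input.png").convert("RGB"),  # placeholder input image
    num_inference_steps=15,
    strength=1.0,
    guidance_scale=2.0,          # > 1 enables classifier-free guidance
    source_guidance_scale=1.0,
    denoise_model=False,         # start from pure noise, keep clean latents as anchor
)
result.images[0].save("edited.png")
```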
ptp_utils.py ADDED
@@ -0,0 +1,180 @@
1
+ # Copyright 2022 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import torch
17
+ from typing import Optional, Union, Tuple, Dict
18
+ from PIL import Image
19
+
20
+ def save_images(images,dest, num_rows=1, offset_ratio=0.02):
21
+ if type(images) is list:
22
+ num_empty = len(images) % num_rows
23
+ elif images.ndim == 4:
24
+ num_empty = images.shape[0] % num_rows
25
+ else:
26
+ images = [images]
27
+ num_empty = 0
28
+
29
+ pil_img = Image.fromarray(images[-1])
30
+ pil_img.save(dest)
31
+ # display(pil_img)
32
+
33
+
34
+ def save_image(images,dest, num_rows=1, offset_ratio=0.02):
35
+ print(images.shape)
36
+ pil_img = Image.fromarray(images[0])
37
+ pil_img.save(dest)
38
+
39
+ def register_attention_control(model, controller):
40
+ class AttnProcessor():
41
+ def __init__(self,place_in_unet):
42
+ self.place_in_unet = place_in_unet
43
+
44
+ def __call__(self,
45
+ attn,
46
+ hidden_states,
47
+ encoder_hidden_states=None,
48
+ attention_mask=None,
49
+ temb=None,
50
+ scale=1.0,):
51
+ # The `Attention` class can call different attention processors / attention functions
52
+
53
+ residual = hidden_states
54
+
55
+ if attn.spatial_norm is not None:
56
+ hidden_states = attn.spatial_norm(hidden_states, temb)
57
+
58
+ input_ndim = hidden_states.ndim
59
+
60
+ if input_ndim == 4:
61
+ batch_size, channel, height, width = hidden_states.shape
62
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
63
+
64
+ h = attn.heads
65
+ is_cross = encoder_hidden_states is not None
66
+ if encoder_hidden_states is None:
67
+ encoder_hidden_states = hidden_states
68
+ elif attn.norm_cross:
69
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
70
+
71
+ batch_size, sequence_length, _ = (
72
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
73
+ )
74
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
75
+
76
+ q = attn.to_q(hidden_states)
77
+ k = attn.to_k(encoder_hidden_states)
78
+ v = attn.to_v(encoder_hidden_states)
79
+ q = attn.head_to_batch_dim(q)
80
+ k = attn.head_to_batch_dim(k)
81
+ v = attn.head_to_batch_dim(v)
82
+
83
+ if not is_cross:
84
+ q,k,v = controller.self_attn_forward(q, k, v, attn.heads)
85
+
86
+ attention_probs = attn.get_attention_scores(q, k, attention_mask)
87
+ if is_cross:
88
+ attention_probs = controller(attention_probs, is_cross, self.place_in_unet)
89
+ # else:
90
+ # out = controller.self_attn_forward(q, k, v, sim, attention_probs , is_cross, self.place_in_unet, attn.heads, scale=attn.scale)
91
+ hidden_states = torch.bmm(attention_probs, v)
92
+ hidden_states = attn.batch_to_head_dim(hidden_states)
93
+
94
+ # linear proj
95
+ hidden_states = attn.to_out[0](hidden_states, scale=scale)
96
+ # dropout
97
+ hidden_states = attn.to_out[1](hidden_states)
98
+
99
+ if input_ndim == 4:
100
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
101
+
102
+ if attn.residual_connection:
103
+ hidden_states = hidden_states + residual
104
+
105
+ hidden_states = hidden_states / attn.rescale_output_factor
106
+
107
+ return hidden_states
108
+
109
+
110
+ def register_recr(net_, count, place_in_unet):
111
+ for idx, m in enumerate(net_.modules()):
112
+ # print(m.__class__.__name__)
113
+ if m.__class__.__name__ == "Attention":
114
+ count += 1
+ m.processor = AttnProcessor(place_in_unet)
116
+ return count
117
+
118
+ cross_att_count = 0
119
+ sub_nets = model.unet.named_children()
120
+ for net in sub_nets:
121
+ if "down" in net[0]:
122
+ cross_att_count += register_recr(net[1], 0, "down")
123
+ elif "up" in net[0]:
124
+ cross_att_count += register_recr(net[1], 0, "up")
125
+ elif "mid" in net[0]:
126
+ cross_att_count += register_recr(net[1], 0, "mid")
127
+ controller.num_att_layers = cross_att_count
128
+
129
+
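For reference, a minimal sketch of the controller interface that `register_attention_control` and the `AttnProcessor` above assume: a `self_attn_forward(q, k, v, heads)` hook for self-attention, a callable hook for cross-attention maps, and a writable `num_att_layers` attribute. This is a hypothetical no-op placeholder, not the project's actual controller.

```python
class NoOpController:
    """Hypothetical do-nothing controller matching the hooks used above."""

    def __init__(self):
        self.num_att_layers = 0  # overwritten by register_attention_control

    def self_attn_forward(self, q, k, v, num_heads):
        # self-attention hook: a real controller may share or remap q/k/v across the batch
        return q, k, v

    def __call__(self, attention_probs, is_cross, place_in_unet):
        # cross-attention hook: a real controller may record or edit the attention maps
        return attention_probs
```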
130
+ def get_word_inds(text: str, word_place: int, tokenizer):
131
+ split_text = text.split(" ")
132
+ if type(word_place) is str:
133
+ word_place = [i for i, word in enumerate(split_text) if word_place == word]
134
+ elif type(word_place) is int:
135
+ word_place = [word_place]
136
+ out = []
137
+ if len(word_place) > 0:
138
+ words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
139
+ cur_len, ptr = 0, 0
140
+
141
+ for i in range(len(words_encode)):
142
+ cur_len += len(words_encode[i])
143
+ if ptr in word_place:
144
+ out.append(i + 1)
145
+ if cur_len >= len(split_text[ptr]):
146
+ ptr += 1
147
+ cur_len = 0
148
+ return np.array(out)
149
+
150
+
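A small illustration of `get_word_inds`: it returns the token positions (offset by one for the BOS token) of a word in the prompt. The indices in the comments assume a CLIP-style tokenizer where each of these words is a single token, so they are illustrative rather than guaranteed.

```python
# Hypothetical usage; exact indices depend on the tokenizer's vocabulary.
# from transformers import CLIPTokenizer
# tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
#
# get_word_inds("a fluffy cat on a bench", "cat", tokenizer)
# -> array([3])   # "cat" is the 3rd word; +1 accounts for the BOS token
# get_word_inds("a fluffy cat on a bench", 1, tokenizer)
# -> array([2])   # word index 1 selects "fluffy"
```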
151
+ def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, word_inds: Optional[torch.Tensor]=None):
152
+ if type(bounds) is float:
153
+ bounds = 0, bounds
154
+ start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
155
+ if word_inds is None:
156
+ word_inds = torch.arange(alpha.shape[2])
157
+ alpha[: start, prompt_ind, word_inds] = 0
158
+ alpha[start: end, prompt_ind, word_inds] = 1
159
+ alpha[end:, prompt_ind, word_inds] = 0
160
+ return alpha
161
+
162
+
163
+ def get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
164
+ tokenizer, max_num_words=77):
165
+ if type(cross_replace_steps) is not dict:
166
+ cross_replace_steps = {"default_": cross_replace_steps}
167
+ if "default_" not in cross_replace_steps:
168
+ cross_replace_steps["default_"] = (0., 1.)
169
+ alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
170
+ for i in range(len(prompts) - 1):
171
+ alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
172
+ i)
173
+ for key, item in cross_replace_steps.items():
174
+ if key != "default_":
175
+ inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
176
+ for i, ind in enumerate(inds):
177
+ if len(ind) > 0:
178
+ alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
179
+ alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) # time, batch, heads, pixels, words
180
+ return alpha_time_words
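A brief, hedged sketch of driving `get_time_words_attention_alpha`: `cross_replace_steps` can be a single float or tuple applied to every word, or a dict with a `"default_"` entry plus per-word overrides. The `tokenizer` below is a placeholder for the pipeline's CLIP tokenizer.

```python
prompts = ["a photo of a cat", "a photo of a dog"]
cross_replace_steps = {
    "default_": 0.8,     # replace cross-attention for the first 80% of steps
    "dog": (0.0, 0.4),   # but only for the first 40% of steps for the word "dog"
}
# alphas = get_time_words_attention_alpha(prompts, num_steps=50,
#                                         cross_replace_steps=cross_replace_steps,
#                                         tokenizer=tokenizer)
# alphas.shape == (51, 1, 1, 1, 77)  # (time, batch, heads, pixels, words)
```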
requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ accelerate
2
+ torch
3
+ torchvision
4
+ git+https://github.com/huggingface/diffusers.git
5
+ Pillow
6
+ transformers
7
+ opencv-python
8
+ openai
seq_aligner.py ADDED
@@ -0,0 +1,314 @@
1
+ import torch
2
+ import copy
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+
7
+ class ScoreParams:
8
+
9
+ def __init__(self, gap, match, mismatch):
10
+ self.gap = gap
11
+ self.match = match
12
+ self.mismatch = mismatch
13
+
14
+ def mis_match_char(self, x, y):
15
+ if x != y:
16
+ return self.mismatch
17
+ else:
18
+ return self.match
19
+
20
+
21
+ # Pure-Python scoring matrix kept for reference; it is shadowed by the vectorized
+ # NumPy implementation of the same name defined just below.
+ def get_matrix(size_x, size_y, gap):
+ matrix = []
+ for i in range(size_x + 1):
+ sub_matrix = []
+ for j in range(size_y + 1):
+ sub_matrix.append(0)
+ matrix.append(sub_matrix)
+ for j in range(1, size_y + 1):
+ matrix[0][j] = j * gap
+ for i in range(1, size_x + 1):
+ matrix[i][0] = i * gap
+ return matrix
33
+
34
+
35
+ def get_matrix(size_x, size_y, gap):
36
+ matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
37
+ matrix[0, 1:] = (np.arange(size_y) + 1) * gap
38
+ matrix[1:, 0] = (np.arange(size_x) + 1) * gap
39
+ return matrix
40
+
41
+
42
+ def get_traceback_matrix(size_x, size_y):
43
+ matrix = np.zeros((size_x + 1, size_y +1), dtype=np.int32)
44
+ matrix[0, 1:] = 1
45
+ matrix[1:, 0] = 2
46
+ matrix[0, 0] = 4
47
+ return matrix
48
+
49
+
50
+ def global_align(x, y, score):
51
+ matrix = get_matrix(len(x), len(y), score.gap)
52
+ trace_back = get_traceback_matrix(len(x), len(y))
53
+ for i in range(1, len(x) + 1):
54
+ for j in range(1, len(y) + 1):
55
+ left = matrix[i, j - 1] + score.gap
56
+ up = matrix[i - 1, j] + score.gap
57
+ diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])
58
+ matrix[i, j] = max(left, up, diag)
59
+ if matrix[i, j] == left:
60
+ trace_back[i, j] = 1
61
+ elif matrix[i, j] == up:
62
+ trace_back[i, j] = 2
63
+ else:
64
+ trace_back[i, j] = 3
65
+ return matrix, trace_back
66
+
67
+
68
+ def get_aligned_sequences(x, y, trace_back):
69
+ x_seq = []
70
+ y_seq = []
71
+ i = len(x)
72
+ j = len(y)
73
+ mapper_y_to_x = []
74
+ while i > 0 or j > 0:
75
+ if trace_back[i, j] == 3:
76
+ x_seq.append(x[i-1])
77
+ y_seq.append(y[j-1])
78
+ i = i-1
79
+ j = j-1
80
+ mapper_y_to_x.append((j, i))
81
+ elif trace_back[i][j] == 1:
82
+ x_seq.append('-')
83
+ y_seq.append(y[j-1])
84
+ j = j-1
85
+ mapper_y_to_x.append((j, -1))
86
+ elif trace_back[i][j] == 2:
87
+ x_seq.append(x[i-1])
88
+ y_seq.append('-')
89
+ i = i-1
90
+ elif trace_back[i][j] == 4:
91
+ break
92
+ mapper_y_to_x.reverse()
93
+ return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64)
94
+
95
+
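`global_align` and `get_aligned_sequences` together implement a small Needleman-Wunsch global alignment over token id sequences. Here is a toy example on plain integer lists; the ids are illustrative and no tokenizer is required.

```python
from seq_aligner import ScoreParams, global_align, get_aligned_sequences  # this module

score = ScoreParams(gap=0, match=1, mismatch=-1)
x = [49406, 320, 2368, 49407]        # e.g. "<bos> a cat <eos>" (illustrative ids)
y = [49406, 320, 1929, 2368, 49407]  # e.g. "<bos> a cute cat <eos>"

matrix, trace_back = global_align(x, y, score)
x_seq, y_seq, mapper_y_to_x = get_aligned_sequences(x, y, trace_back)
# mapper_y_to_x[i] = (i, j) maps position i of y onto position j of x,
# with j == -1 where y has an inserted token that x lacks (here: position 2).
```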
96
+ def get_mapper(x: str, y: str, specifier, tokenizer, encoder, device, max_len=77):
97
+ local_prompt, mutual_prompt = specifier
+ x_seq = tokenizer.encode(x)
+ y_seq = tokenizer.encode(y)
+ e_seq = tokenizer.encode(local_prompt)
101
+ m_seq = tokenizer.encode(mutual_prompt)
102
+ score = ScoreParams(0, 1, -1)
103
+ matrix, trace_back = global_align(x_seq, y_seq, score)
104
+ mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]
105
+ alphas = torch.ones(max_len)
106
+ alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float()
107
+ mapper = torch.zeros(max_len, dtype=torch.int64)
108
+ mapper[:mapper_base.shape[0]] = mapper_base[:, 1]
109
+ mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq))
110
+ m = copy.deepcopy(alphas)
111
+ alpha_e = torch.zeros_like(alphas)
112
+ alpha_m = torch.zeros_like(alphas)
113
+
114
+ # print("mapper of")
115
+ # print("<begin> "+x+" <end>")
116
+ # print("<begin> "+y+" <end>")
117
+ # print(mapper[:len(y_seq)])
118
+ # print(alphas[:len(y_seq)])
119
+
120
+ x = tokenizer(
121
+ x,
122
+ padding="max_length",
123
+ max_length=max_len,
124
+ truncation=True,
125
+ return_tensors="pt",
126
+ ).input_ids.to(device)
127
+ y = tokenizer(
128
+ y,
129
+ padding="max_length",
130
+ max_length=max_len,
131
+ truncation=True,
132
+ return_tensors="pt",
133
+ ).input_ids.to(device)
134
+
135
+ x_latent = encoder(x)[0].squeeze(0)
136
+ y_latent = encoder(y)[0].squeeze(0)
137
+ i = 0
138
+ while i<len(y_seq):
139
+ start = None
140
+ if alphas[i] == 0:
141
+ start = i
142
+ while alphas[i] == 0:
143
+ i += 1
144
+ max_sim = float('-inf')
145
+ max_s = None
146
+ max_t = None
147
+ for i_target in range(start, i):
148
+ for i_source in range(mapper[start-1]+1, mapper[i]):
149
+ sim = F.cosine_similarity(x_latent[i_target], y_latent[i_source], dim=0)
150
+ if sim > max_sim:
151
+ max_sim = sim
152
+ max_s = i_source
153
+ max_t = i_target
154
+ if max_s is not None:
155
+ mapper[max_t] = max_s
156
+ alphas[max_t] = 1
157
+ for t in e_seq:
158
+ if x_seq[max_s] == t:
159
+ alpha_e[max_t] = 1
160
+ i += 1
161
+
162
+ # replace_alpha, replace_mapper = get_replace_inds(x_seq, y_seq, m_seq, m_seq)
163
+ # if replace_mapper != []:
164
+ # mapper[replace_alpha]=torch.tensor(replace_mapper,device=mapper.device)
165
+ # alpha_m[replace_alpha]=1
166
+
167
+ i = 1
168
+ j = 1
169
+ while (i < len(y_seq)-1) and (j < len(e_seq)-1):
170
+ found = True
171
+ while e_seq[j] != y_seq[i]:
172
+ i = i + 1
173
+ if i >= len(y_seq)-1:
174
+ # the blend (local) word was not found in the target prompt: warn and stop scanning
+ print("blend word not found!")
+ found = False
+ break
178
+ if found:
179
+ alpha_e[i] = 1
180
+ j = j + 1
181
+
182
+ i = 1
183
+ j = 1
184
+ while (i < len(y_seq)-1) and (j < len(m_seq)-1):
185
+ while m_seq[j] != y_seq[i]:
186
+ i = i + 1
187
+ if m_seq[j] == x_seq[mapper[i]]:
188
+ alpha_m[i] = 1
189
+ j = j + 1
190
+ else:
191
+ raise ValueError("mutual prompt not found in target prompt")
192
+
193
+ # print("fixed mapper:")
194
+ # print(mapper[:len(y_seq)])
195
+ # print(alphas[:len(y_seq)])
196
+ # print(m[:len(y_seq)])
197
+ # print(alpha_e[:len(y_seq)])
198
+ # print(alpha_m[:len(y_seq)])
199
+ return mapper, alphas, m, alpha_e, alpha_m
200
+
201
+
202
+ def get_refinement_mapper(prompts, specifiers, tokenizer, encoder, device, max_len=77):
203
+ x_seq = prompts[0]
204
+ mappers, alphas, ms, alpha_objs, alpha_descs = [], [], [], [], []
205
+ for i in range(1, len(prompts)):
206
+ mapper, alpha, m, alpha_obj, alpha_desc = get_mapper(x_seq, prompts[i], specifiers[i-1], tokenizer, encoder, device, max_len)
207
+ mappers.append(mapper)
208
+ alphas.append(alpha)
209
+ ms.append(m)
210
+ alpha_objs.append(alpha_obj)
211
+ alpha_descs.append(alpha_desc)
212
+ return torch.stack(mappers), torch.stack(alphas), torch.stack(ms), torch.stack(alpha_objs), torch.stack(alpha_descs)
213
+
214
+
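A hedged sketch of how `get_refinement_mapper` is typically driven: the first prompt is the source, each later prompt is a target, and each target carries a `(local_prompt, mutual_prompt)` specifier. `tokenizer`, `text_encoder`, and `device` stand in for the pipeline's CLIP components and are not defined here.

```python
prompts = ["a photo of a cat", "a photo of a dog wearing a hat"]
specifiers = [("hat", "")]  # (local/blend words, mutual words) for the one target prompt

# mappers, alphas, ms, alpha_e, alpha_m = get_refinement_mapper(
#     prompts, specifiers, tokenizer, text_encoder, device, max_len=77)
# mappers[0][i] gives, for token i of the target prompt, the aligned token index in
# the source prompt; alphas[0][i] is (roughly) 0 for tokens with no source counterpart.
```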
215
+ def get_replace_inds(x_seq,y_seq,source_replace_seq,target_replace_seq):
216
+ replace_mapper=[]
217
+ replace_alpha=[]
218
+ source_found=False
219
+ source_match,target_match=[],[]
220
+ for j in range(len(x_seq)):
221
+ found=True
222
+ for i in range(1,len(source_replace_seq)-1):
223
+ if x_seq[j+i-1]!=source_replace_seq[i]:
224
+ found=False
225
+ break
226
+ if found:
227
+ source_found=True
228
+ for i in range(1,len(source_replace_seq)-1):
229
+ source_match.append(j+i-1)
230
+ for j in range(len(y_seq)):
231
+ found=True
232
+ for i in range(1,len(target_replace_seq)-1):
233
+ if y_seq[j+i-1]!=target_replace_seq[i]:
234
+ found=False
235
+ break
236
+ if found:
237
+ for i in range(1,len(source_replace_seq)-1):
238
+ target_match.append(j+i-1)
239
+ if not source_found:
240
+ raise ValueError("replacing object not found in prompt")
241
+ if (len(source_match)!=len(target_match)):
242
+ raise ValueError(f"the replacement word number doesn't match for word {i}!")
243
+ replace_alpha+=source_match
244
+ replace_mapper+=target_match
245
+ return replace_alpha,replace_mapper
246
+
247
+
248
+
249
+ def get_word_inds(text: str, word_place: int, tokenizer):
250
+ split_text = text.split(" ")
251
+ if type(word_place) is str:
252
+ word_place = [i for i, word in enumerate(split_text) if word_place == word]
253
+ elif type(word_place) is int:
254
+ word_place = [word_place]
255
+ out = []
256
+ if len(word_place) > 0:
257
+ words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
258
+ cur_len, ptr = 0, 0
259
+
260
+ for i in range(len(words_encode)):
261
+ cur_len += len(words_encode[i])
262
+ if ptr in word_place:
263
+ out.append(i + 1)
264
+ if cur_len >= len(split_text[ptr]):
265
+ ptr += 1
266
+ cur_len = 0
267
+ return np.array(out)
268
+
269
+
270
+ def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
271
+ words_x = x.split(' ')
272
+ words_y = y.split(' ')
273
+ if len(words_x) != len(words_y):
274
+ raise ValueError(f"attention replacement edit can only be applied on prompts with the same length"
275
+ f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.")
276
+ inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
277
+ inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
278
+ inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
279
+ mapper = np.zeros((max_len, max_len))
280
+ i = j = 0
281
+ cur_inds = 0
282
+ while i < max_len and j < max_len:
283
+ if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
284
+ inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
285
+ if len(inds_source_) == len(inds_target_):
286
+ mapper[inds_source_, inds_target_] = 1
287
+ else:
288
+ ratio = 1 / len(inds_target_)
289
+ for i_t in inds_target_:
290
+ mapper[inds_source_, i_t] = ratio
291
+ cur_inds += 1
292
+ i += len(inds_source_)
293
+ j += len(inds_target_)
294
+ elif cur_inds < len(inds_source):
295
+ mapper[i, j] = 1
296
+ i += 1
297
+ j += 1
298
+ else:
299
+ mapper[j, j] = 1
300
+ i += 1
301
+ j += 1
302
+
303
+ return torch.from_numpy(mapper).float()
304
+
305
+
306
+
307
+ def get_replacement_mapper(prompts, tokenizer, max_len=77):
308
+ x_seq = prompts[0]
309
+ mappers = []
310
+ for i in range(1, len(prompts)):
311
+ mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len)
312
+ mappers.append(mapper)
313
+ return torch.stack(mappers)
314
+
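Finally, a hedged note on `get_replacement_mapper_`: it requires both prompts to have the same number of words and builds a 77x77 soft permutation over token positions; differing word counts raise a `ValueError` by design.

```python
# Hypothetical usage; `tokenizer` is assumed to be the pipeline's CLIP tokenizer.
# mapper = get_replacement_mapper_("a cat sitting on a bench",
#                                  "a dog sitting on a bench", tokenizer)
# mapper.shape == (77, 77): close to an identity matrix except at the token
# positions of "cat"/"dog", where the attention mass is re-routed between them.
```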
utils.py ADDED
@@ -0,0 +1,6 @@
1
+ def is_google_colab():
2
+ try:
3
+ import google.colab
4
+ return True
5
+ except:
6
+ return False