revi13 commited on
Commit
45ffde6
·
1 Parent(s): b5bf896

Restore ip_adapter from models/ipadapter

Browse files
ip_adapter/__init__.py ADDED
File without changes
ip_adapter/ip_adapter_faceid.py ADDED
@@ -0,0 +1,542 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import torch
5
+ from diffusers import StableDiffusionPipeline
6
+ from diffusers.pipelines.controlnet import MultiControlNetModel
7
+ from PIL import Image
8
+ from safetensors import safe_open
9
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
10
+
11
+ from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
12
+ from .utils import is_torch2_available, get_generator
13
+
14
+ USE_DAFAULT_ATTN = False # should be True for visualization_attnmap
15
+ if is_torch2_available() and (not USE_DAFAULT_ATTN):
16
+ from .attention_processor_faceid import (
17
+ LoRAAttnProcessor2_0 as LoRAAttnProcessor,
18
+ )
19
+ from .attention_processor_faceid import (
20
+ LoRAIPAttnProcessor2_0 as LoRAIPAttnProcessor,
21
+ )
22
+ else:
23
+ from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
24
+ from .resampler import PerceiverAttention, FeedForward
25
+
26
+
27
+ class FacePerceiverResampler(torch.nn.Module):
28
+ def __init__(
29
+ self,
30
+ *,
31
+ dim=768,
32
+ depth=4,
33
+ dim_head=64,
34
+ heads=16,
35
+ embedding_dim=1280,
36
+ output_dim=768,
37
+ ff_mult=4,
38
+ ):
39
+ super().__init__()
40
+
41
+ self.proj_in = torch.nn.Linear(embedding_dim, dim)
42
+ self.proj_out = torch.nn.Linear(dim, output_dim)
43
+ self.norm_out = torch.nn.LayerNorm(output_dim)
44
+ self.layers = torch.nn.ModuleList([])
45
+ for _ in range(depth):
46
+ self.layers.append(
47
+ torch.nn.ModuleList(
48
+ [
49
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
50
+ FeedForward(dim=dim, mult=ff_mult),
51
+ ]
52
+ )
53
+ )
54
+
55
+ def forward(self, latents, x):
56
+ x = self.proj_in(x)
57
+ for attn, ff in self.layers:
58
+ latents = attn(x, latents) + latents
59
+ latents = ff(latents) + latents
60
+ latents = self.proj_out(latents)
61
+ return self.norm_out(latents)
62
+
63
+
64
+ class MLPProjModel(torch.nn.Module):
65
+ def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
66
+ super().__init__()
67
+
68
+ self.cross_attention_dim = cross_attention_dim
69
+ self.num_tokens = num_tokens
70
+
71
+ self.proj = torch.nn.Sequential(
72
+ torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
73
+ torch.nn.GELU(),
74
+ torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
75
+ )
76
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
77
+
78
+ def forward(self, id_embeds):
79
+ x = self.proj(id_embeds)
80
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
81
+ x = self.norm(x)
82
+ return x
83
+
84
+
85
+ class ProjPlusModel(torch.nn.Module):
86
+ def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4):
87
+ super().__init__()
88
+
89
+ self.cross_attention_dim = cross_attention_dim
90
+ self.num_tokens = num_tokens
91
+
92
+ self.proj = torch.nn.Sequential(
93
+ torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
94
+ torch.nn.GELU(),
95
+ torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
96
+ )
97
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
98
+
99
+ self.perceiver_resampler = FacePerceiverResampler(
100
+ dim=cross_attention_dim,
101
+ depth=4,
102
+ dim_head=64,
103
+ heads=cross_attention_dim // 64,
104
+ embedding_dim=clip_embeddings_dim,
105
+ output_dim=cross_attention_dim,
106
+ ff_mult=4,
107
+ )
108
+
109
+ def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0):
110
+
111
+ x = self.proj(id_embeds)
112
+ x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
113
+ x = self.norm(x)
114
+ out = self.perceiver_resampler(x, clip_embeds)
115
+ if shortcut:
116
+ out = x + scale * out
117
+ return out
118
+
119
+
120
+ class IPAdapterFaceID:
121
+ def __init__(self, sd_pipe, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16):
122
+ self.device = device
123
+ self.ip_ckpt = ip_ckpt
124
+ self.lora_rank = lora_rank
125
+ self.num_tokens = num_tokens
126
+ self.torch_dtype = torch_dtype
127
+
128
+ self.pipe = sd_pipe.to(self.device)
129
+ self.set_ip_adapter()
130
+
131
+ # image proj model
132
+ self.image_proj_model = self.init_proj()
133
+
134
+ self.load_ip_adapter()
135
+
136
+ def init_proj(self):
137
+ image_proj_model = MLPProjModel(
138
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
139
+ id_embeddings_dim=512,
140
+ num_tokens=self.num_tokens,
141
+ ).to(self.device, dtype=self.torch_dtype)
142
+ return image_proj_model
143
+
144
+ def set_ip_adapter(self):
145
+ unet = self.pipe.unet
146
+ attn_procs = {}
147
+ for name in unet.attn_processors.keys():
148
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
149
+ if name.startswith("mid_block"):
150
+ hidden_size = unet.config.block_out_channels[-1]
151
+ elif name.startswith("up_blocks"):
152
+ block_id = int(name[len("up_blocks.")])
153
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
154
+ elif name.startswith("down_blocks"):
155
+ block_id = int(name[len("down_blocks.")])
156
+ hidden_size = unet.config.block_out_channels[block_id]
157
+ if cross_attention_dim is None:
158
+ attn_procs[name] = LoRAAttnProcessor(
159
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
160
+ ).to(self.device, dtype=self.torch_dtype)
161
+ else:
162
+ attn_procs[name] = LoRAIPAttnProcessor(
163
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
164
+ ).to(self.device, dtype=self.torch_dtype)
165
+ unet.set_attn_processor(attn_procs)
166
+
167
+ def load_ip_adapter(self):
168
+ if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
169
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
170
+ with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
171
+ for key in f.keys():
172
+ if key.startswith("image_proj."):
173
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
174
+ elif key.startswith("ip_adapter."):
175
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
176
+ else:
177
+ state_dict = torch.load(self.ip_ckpt, map_location="cpu")
178
+ self.image_proj_model.load_state_dict(state_dict["image_proj"])
179
+ ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
180
+ ip_layers.load_state_dict(state_dict["ip_adapter"])
181
+
182
+ @torch.inference_mode()
183
+ def get_image_embeds(self, faceid_embeds):
184
+
185
+ faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
186
+ image_prompt_embeds = self.image_proj_model(faceid_embeds)
187
+ uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds))
188
+ return image_prompt_embeds, uncond_image_prompt_embeds
189
+
190
+ def set_scale(self, scale):
191
+ for attn_processor in self.pipe.unet.attn_processors.values():
192
+ if isinstance(attn_processor, LoRAIPAttnProcessor):
193
+ attn_processor.scale = scale
194
+
195
+ def generate(
196
+ self,
197
+ faceid_embeds=None,
198
+ prompt=None,
199
+ negative_prompt=None,
200
+ scale=1.0,
201
+ num_samples=4,
202
+ seed=None,
203
+ guidance_scale=7.5,
204
+ num_inference_steps=30,
205
+ **kwargs,
206
+ ):
207
+ self.set_scale(scale)
208
+
209
+
210
+ num_prompts = faceid_embeds.size(0)
211
+
212
+ if prompt is None:
213
+ prompt = "best quality, high quality"
214
+ if negative_prompt is None:
215
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
216
+
217
+ if not isinstance(prompt, List):
218
+ prompt = [prompt] * num_prompts
219
+ if not isinstance(negative_prompt, List):
220
+ negative_prompt = [negative_prompt] * num_prompts
221
+
222
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds)
223
+
224
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
225
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
226
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
227
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
228
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
229
+
230
+ with torch.inference_mode():
231
+ prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
232
+ prompt,
233
+ device=self.device,
234
+ num_images_per_prompt=num_samples,
235
+ do_classifier_free_guidance=True,
236
+ negative_prompt=negative_prompt,
237
+ )
238
+ prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
239
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
240
+
241
+ generator = get_generator(seed, self.device)
242
+
243
+ images = self.pipe(
244
+ prompt_embeds=prompt_embeds,
245
+ negative_prompt_embeds=negative_prompt_embeds,
246
+ guidance_scale=guidance_scale,
247
+ num_inference_steps=num_inference_steps,
248
+ generator=generator,
249
+ **kwargs,
250
+ ).images
251
+
252
+ return images
253
+
254
+
255
+ class IPAdapterFaceIDPlus:
256
+ def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, lora_rank=128, num_tokens=4, torch_dtype=torch.float16):
257
+ self.device = device
258
+ self.image_encoder_path = image_encoder_path
259
+ self.ip_ckpt = ip_ckpt
260
+ self.lora_rank = lora_rank
261
+ self.num_tokens = num_tokens
262
+ self.torch_dtype = torch_dtype
263
+
264
+ self.pipe = sd_pipe.to(self.device)
265
+ self.set_ip_adapter()
266
+
267
+ # load image encoder
268
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
269
+ self.device, dtype=self.torch_dtype
270
+ )
271
+ self.clip_image_processor = CLIPImageProcessor()
272
+ # image proj model
273
+ self.image_proj_model = self.init_proj()
274
+
275
+ self.load_ip_adapter()
276
+
277
+ def init_proj(self):
278
+ image_proj_model = ProjPlusModel(
279
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
280
+ id_embeddings_dim=512,
281
+ clip_embeddings_dim=self.image_encoder.config.hidden_size,
282
+ num_tokens=self.num_tokens,
283
+ ).to(self.device, dtype=self.torch_dtype)
284
+ return image_proj_model
285
+
286
+ def set_ip_adapter(self):
287
+ unet = self.pipe.unet
288
+ attn_procs = {}
289
+ for name in unet.attn_processors.keys():
290
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
291
+ if name.startswith("mid_block"):
292
+ hidden_size = unet.config.block_out_channels[-1]
293
+ elif name.startswith("up_blocks"):
294
+ block_id = int(name[len("up_blocks.")])
295
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
296
+ elif name.startswith("down_blocks"):
297
+ block_id = int(name[len("down_blocks.")])
298
+ hidden_size = unet.config.block_out_channels[block_id]
299
+ if cross_attention_dim is None:
300
+ attn_procs[name] = LoRAAttnProcessor(
301
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=self.lora_rank,
302
+ ).to(self.device, dtype=self.torch_dtype)
303
+ else:
304
+ attn_procs[name] = LoRAIPAttnProcessor(
305
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, rank=self.lora_rank, num_tokens=self.num_tokens,
306
+ ).to(self.device, dtype=self.torch_dtype)
307
+ unet.set_attn_processor(attn_procs)
308
+
309
+ def load_ip_adapter(self):
310
+ if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
311
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
312
+ with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
313
+ for key in f.keys():
314
+ if key.startswith("image_proj."):
315
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
316
+ elif key.startswith("ip_adapter."):
317
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
318
+ else:
319
+ state_dict = torch.load(self.ip_ckpt, map_location="cpu")
320
+ self.image_proj_model.load_state_dict(state_dict["image_proj"])
321
+ ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
322
+ ip_layers.load_state_dict(state_dict["ip_adapter"])
323
+
324
+ @torch.inference_mode()
325
+ def get_image_embeds(self, faceid_embeds, face_image, s_scale, shortcut):
326
+ if isinstance(face_image, Image.Image):
327
+ pil_image = [face_image]
328
+ clip_image = self.clip_image_processor(images=face_image, return_tensors="pt").pixel_values
329
+ clip_image = clip_image.to(self.device, dtype=self.torch_dtype)
330
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
331
+ uncond_clip_image_embeds = self.image_encoder(
332
+ torch.zeros_like(clip_image), output_hidden_states=True
333
+ ).hidden_states[-2]
334
+
335
+ faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
336
+ image_prompt_embeds = self.image_proj_model(faceid_embeds, clip_image_embeds, shortcut=shortcut, scale=s_scale)
337
+ uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds), uncond_clip_image_embeds, shortcut=shortcut, scale=s_scale)
338
+ return image_prompt_embeds, uncond_image_prompt_embeds
339
+
340
+ def set_scale(self, scale):
341
+ for attn_processor in self.pipe.unet.attn_processors.values():
342
+ if isinstance(attn_processor, LoRAIPAttnProcessor):
343
+ attn_processor.scale = scale
344
+
345
+ def generate(
346
+ self,
347
+ face_image=None,
348
+ faceid_embeds=None,
349
+ prompt=None,
350
+ negative_prompt=None,
351
+ scale=1.0,
352
+ num_samples=4,
353
+ seed=None,
354
+ guidance_scale=7.5,
355
+ num_inference_steps=30,
356
+ s_scale=1.0,
357
+ shortcut=False,
358
+ **kwargs,
359
+ ):
360
+ self.set_scale(scale)
361
+
362
+
363
+ num_prompts = faceid_embeds.size(0)
364
+
365
+ if prompt is None:
366
+ prompt = "best quality, high quality"
367
+ if negative_prompt is None:
368
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
369
+
370
+ if not isinstance(prompt, List):
371
+ prompt = [prompt] * num_prompts
372
+ if not isinstance(negative_prompt, List):
373
+ negative_prompt = [negative_prompt] * num_prompts
374
+
375
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut)
376
+
377
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
378
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
379
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
380
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
381
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
382
+
383
+ with torch.inference_mode():
384
+ prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
385
+ prompt,
386
+ device=self.device,
387
+ num_images_per_prompt=num_samples,
388
+ do_classifier_free_guidance=True,
389
+ negative_prompt=negative_prompt,
390
+ )
391
+ prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
392
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
393
+
394
+ generator = get_generator(seed, self.device)
395
+
396
+ images = self.pipe(
397
+ prompt_embeds=prompt_embeds,
398
+ negative_prompt_embeds=negative_prompt_embeds,
399
+ guidance_scale=guidance_scale,
400
+ num_inference_steps=num_inference_steps,
401
+ generator=generator,
402
+ **kwargs,
403
+ ).images
404
+
405
+ return images
406
+
407
+
408
+ class IPAdapterFaceIDXL(IPAdapterFaceID):
409
+ """SDXL"""
410
+
411
+ def generate(
412
+ self,
413
+ faceid_embeds=None,
414
+ prompt=None,
415
+ negative_prompt=None,
416
+ scale=1.0,
417
+ num_samples=4,
418
+ seed=None,
419
+ num_inference_steps=30,
420
+ **kwargs,
421
+ ):
422
+ self.set_scale(scale)
423
+
424
+ num_prompts = faceid_embeds.size(0)
425
+
426
+ if prompt is None:
427
+ prompt = "best quality, high quality"
428
+ if negative_prompt is None:
429
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
430
+
431
+ if not isinstance(prompt, List):
432
+ prompt = [prompt] * num_prompts
433
+ if not isinstance(negative_prompt, List):
434
+ negative_prompt = [negative_prompt] * num_prompts
435
+
436
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds)
437
+
438
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
439
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
440
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
441
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
442
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
443
+
444
+ with torch.inference_mode():
445
+ (
446
+ prompt_embeds,
447
+ negative_prompt_embeds,
448
+ pooled_prompt_embeds,
449
+ negative_pooled_prompt_embeds,
450
+ ) = self.pipe.encode_prompt(
451
+ prompt,
452
+ num_images_per_prompt=num_samples,
453
+ do_classifier_free_guidance=True,
454
+ negative_prompt=negative_prompt,
455
+ )
456
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
457
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
458
+
459
+ generator = get_generator(seed, self.device)
460
+
461
+ images = self.pipe(
462
+ prompt_embeds=prompt_embeds,
463
+ negative_prompt_embeds=negative_prompt_embeds,
464
+ pooled_prompt_embeds=pooled_prompt_embeds,
465
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
466
+ num_inference_steps=num_inference_steps,
467
+ generator=generator,
468
+ **kwargs,
469
+ ).images
470
+
471
+ return images
472
+
473
+
474
+ class IPAdapterFaceIDPlusXL(IPAdapterFaceIDPlus):
475
+ """SDXL"""
476
+
477
+ def generate(
478
+ self,
479
+ face_image=None,
480
+ faceid_embeds=None,
481
+ prompt=None,
482
+ negative_prompt=None,
483
+ scale=1.0,
484
+ num_samples=4,
485
+ seed=None,
486
+ guidance_scale=7.5,
487
+ num_inference_steps=30,
488
+ s_scale=1.0,
489
+ shortcut=True,
490
+ **kwargs,
491
+ ):
492
+ self.set_scale(scale)
493
+
494
+ num_prompts = faceid_embeds.size(0)
495
+
496
+ if prompt is None:
497
+ prompt = "best quality, high quality"
498
+ if negative_prompt is None:
499
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
500
+
501
+ if not isinstance(prompt, List):
502
+ prompt = [prompt] * num_prompts
503
+ if not isinstance(negative_prompt, List):
504
+ negative_prompt = [negative_prompt] * num_prompts
505
+
506
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(faceid_embeds, face_image, s_scale, shortcut)
507
+
508
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
509
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
510
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
511
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
512
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
513
+
514
+ with torch.inference_mode():
515
+ (
516
+ prompt_embeds,
517
+ negative_prompt_embeds,
518
+ pooled_prompt_embeds,
519
+ negative_pooled_prompt_embeds,
520
+ ) = self.pipe.encode_prompt(
521
+ prompt,
522
+ num_images_per_prompt=num_samples,
523
+ do_classifier_free_guidance=True,
524
+ negative_prompt=negative_prompt,
525
+ )
526
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
527
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
528
+
529
+ generator = get_generator(seed, self.device)
530
+
531
+ images = self.pipe(
532
+ prompt_embeds=prompt_embeds,
533
+ negative_prompt_embeds=negative_prompt_embeds,
534
+ pooled_prompt_embeds=pooled_prompt_embeds,
535
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
536
+ num_inference_steps=num_inference_steps,
537
+ generator=generator,
538
+ guidance_scale=guidance_scale,
539
+ **kwargs,
540
+ ).images
541
+
542
+ return images