Leimingkun commited on
Commit
aa91ca4
1 Parent(s): 672430d

stylestudio

Browse files
Files changed (2) hide show
  1. ip_adapter/__init__.py +3 -6
  2. ip_adapter/ip_adapter.py +29 -0
ip_adapter/__init__.py CHANGED
@@ -1,15 +1,12 @@
1
- from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull,IPAdapterXL_CS,IPAdapter_CS
2
  from .ip_adapter import CSGO
3
- from .ip_adapter import StyleStudio_Adapter, StyleStudio_Adapter_exp
4
- from .ip_adapter import IPAdapterXL_cross_modal
5
  __all__ = [
6
  "IPAdapter",
7
  "IPAdapterPlus",
8
  "IPAdapterPlusXL",
9
  "IPAdapterXL",
 
10
  "CSGO",
11
  "StyleStudio_Adapter",
12
- "StyleStudio_Adapter_exp",
13
- "IPAdapterXL_cross_modal",
14
- "IPAdapterFull",
15
  ]
 
1
+ from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull
2
  from .ip_adapter import CSGO
3
+ from .ip_adapter import StyleStudio_Adapter
 
4
  __all__ = [
5
  "IPAdapter",
6
  "IPAdapterPlus",
7
  "IPAdapterPlusXL",
8
  "IPAdapterXL",
9
+ "IPAdapterFull",
10
  "CSGO",
11
  "StyleStudio_Adapter",
 
 
 
12
  ]
ip_adapter/ip_adapter.py CHANGED
@@ -246,6 +246,35 @@ class IPAdapter:
246
 
247
  return images
248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
250
  class IPAdapter_CS:
251
  def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_content_tokens=4,
 
246
 
247
  return images
248
 
249
+ class IPAdapterPlus(IPAdapter):
250
+ """IP-Adapter with fine-grained features"""
251
+
252
+ def init_proj(self):
253
+ image_proj_model = Resampler(
254
+ dim=self.pipe.unet.config.cross_attention_dim,
255
+ depth=4,
256
+ dim_head=64,
257
+ heads=12,
258
+ num_queries=self.num_tokens,
259
+ embedding_dim=self.image_encoder.config.hidden_size,
260
+ output_dim=self.pipe.unet.config.cross_attention_dim,
261
+ ff_mult=4,
262
+ ).to(self.device, dtype=torch.float16)
263
+ return image_proj_model
264
+
265
+ @torch.inference_mode()
266
+ def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
267
+ if isinstance(pil_image, Image.Image):
268
+ pil_image = [pil_image]
269
+ clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
270
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
271
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
272
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
273
+ uncond_clip_image_embeds = self.image_encoder(
274
+ torch.zeros_like(clip_image), output_hidden_states=True
275
+ ).hidden_states[-2]
276
+ uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
277
+ return image_prompt_embeds, uncond_image_prompt_embeds
278
 
279
  class IPAdapter_CS:
280
  def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_content_tokens=4,