luluxxx commited on
Commit
9981cb4
1 Parent(s): efe2cc1

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +17 -6
pipeline.py CHANGED
@@ -378,13 +378,24 @@ class AnimateDiffControlNetPipeline(
378
 
379
  if not isinstance(image, torch.Tensor):
380
  image = self.feature_extractor(image, return_tensors="pt").pixel_values
381
-
382
  image = image.to(device=device, dtype=dtype)
383
- image_embeds = self.image_encoder(image).image_embeds
384
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
385
-
386
- uncond_image_embeds = torch.zeros_like(image_embeds)
387
- return image_embeds, uncond_image_embeds
 
 
 
 
 
 
 
 
 
 
 
388
 
389
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
390
  def prepare_ip_adapter_image_embeds(
 
378
 
379
  if not isinstance(image, torch.Tensor):
380
  image = self.feature_extractor(image, return_tensors="pt").pixel_values
381
+
382
  image = image.to(device=device, dtype=dtype)
383
+ if output_hidden_states:
384
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
385
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
386
+ uncond_image_enc_hidden_states = self.image_encoder(
387
+ torch.zeros_like(image), output_hidden_states=True
388
+ ).hidden_states[-2]
389
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
390
+ num_images_per_prompt, dim=0
391
+ )
392
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
393
+ else:
394
+ image_embeds = self.image_encoder(image).image_embeds
395
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
396
+ uncond_image_embeds = torch.zeros_like(image_embeds)
397
+
398
+ return image_embeds, uncond_image_embeds
399
 
400
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
401
  def prepare_ip_adapter_image_embeds(