luluxxx committed
Commit
61e96c8
1 Parent(s): 9981cb4

Update pipeline.py

Files changed (1)
  pipeline.py  +20 -20
pipeline.py CHANGED
@@ -373,29 +373,29 @@ class AnimateDiffControlNetPipeline(
         return prompt_embeds, negative_prompt_embeds
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
-    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
-        dtype = next(self.image_encoder.parameters()).dtype
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
 
-        if not isinstance(image, torch.Tensor):
-            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
 
-        image = image.to(device=device, dtype=dtype)
-        if output_hidden_states:
-            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
-            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
-            uncond_image_enc_hidden_states = self.image_encoder(
-                torch.zeros_like(image), output_hidden_states=True
-            ).hidden_states[-2]
-            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
-                num_images_per_prompt, dim=0
-            )
-            return image_enc_hidden_states, uncond_image_enc_hidden_states
-        else:
-            image_embeds = self.image_encoder(image).image_embeds
-            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-            uncond_image_embeds = torch.zeros_like(image_embeds)
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
 
-            return image_embeds, uncond_image_embeds
+            return image_embeds, uncond_image_embeds
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
     def prepare_ip_adapter_image_embeds(
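
For context, encode_image turns an IP-Adapter conditioning image into CLIP features (pooled image embeddings, or the penultimate hidden states when output_hidden_states is set) together with a matching unconditional tensor for classifier-free guidance, each repeated num_images_per_prompt times. Below is a minimal stand-alone sketch of the same two branches; the CLIP checkpoint, the image path, and the CPU device are illustrative assumptions, not part of this commit.

# Hypothetical stand-alone sketch (not part of this commit). It mirrors what the
# encode_image helper above computes, using a CLIP vision encoder from
# transformers; the checkpoint name and example image path are assumptions.
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "openai/clip-vit-large-patch14"  # assumed checkpoint for illustration
)
feature_extractor = CLIPImageProcessor()
dtype = next(image_encoder.parameters()).dtype

image = Image.open("reference.png").convert("RGB")  # assumed example image
pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
pixel_values = pixel_values.to(device="cpu", dtype=dtype)
num_images_per_prompt = 2

# Pooled-embedding branch (output_hidden_states falsy): one projected embedding
# per requested copy, with an all-zeros tensor as the unconditional input for
# classifier-free guidance.
image_embeds = image_encoder(pixel_values).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
print(image_embeds.shape, uncond_image_embeds.shape)

# Hidden-states branch (used by IP-Adapter "plus"-style projections): the
# penultimate hidden states of the image and of an all-zeros image, each
# repeated per prompt copy.
hidden_states = image_encoder(pixel_values, output_hidden_states=True).hidden_states[-2]
hidden_states = hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_hidden_states = image_encoder(
    torch.zeros_like(pixel_values), output_hidden_states=True
).hidden_states[-2].repeat_interleave(num_images_per_prompt, dim=0)
print(hidden_states.shape, uncond_hidden_states.shape)

In the pipeline itself these tensors are consumed by prepare_ip_adapter_image_embeds (whose definition begins just after the hunk above) before the denoising loop.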