Yinhong Liu committed
Commit c2eb006 · 1 Parent(s): 3dfb2f9

fix sana pipeline

Files changed (1):
  1. sid/pipeline_sid_sana.py +5 -6
sid/pipeline_sid_sana.py CHANGED
@@ -711,7 +711,7 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_attention_mask: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
@@ -744,7 +744,7 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
             height,
             width,
             prompt_embeds=prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
         )
 
@@ -763,12 +763,12 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
 
         (
             prompt_embeds,
-            pooled_prompt_embeds,
+            prompt_attention_mask,
             _, _,
         ) = self.encode_prompt(
             prompt,
             prompt_embeds=prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
@@ -824,8 +824,7 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                 flow_pred = self.transformer(
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
-                    # encoder_attention_mask=prompt_attention_mask,
-                    pooled_projections=pooled_prompt_embeds,
+                    encoder_attention_mask=prompt_attention_mask,
                     timestep=time_scale * t_flattern,
                     return_dict=False,
                 )[0]
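
For context, Sana's transformer is conditioned on the text-encoder hidden states plus their attention mask (encoder_attention_mask); it has no pooled_projections input (that argument belongs to SD3/Flux-style transformers), which is what this commit corrects by threading prompt_attention_mask through the whole call path. Below is a minimal usage sketch of the fixed path, not the repository's documented API: it assumes SiDSanaPipeline follows standard diffusers conventions (from_pretrained, an encode_prompt that returns the embeddings and attention mask first, and an output object with an .images field), and the checkpoint path and prompt are placeholders.

import torch
from sid.pipeline_sid_sana import SiDSanaPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder checkpoint path; substitute a real SiD-distilled Sana checkpoint.
pipe = SiDSanaPipeline.from_pretrained(
    "path/to/sid-sana-checkpoint", torch_dtype=torch.bfloat16
).to(device)

# Pre-encode the prompt. With this fix, the second return value (the text
# encoder's attention mask) is what the pipeline forwards to the transformer
# as encoder_attention_mask, instead of a non-existent pooled projection.
prompt_embeds, prompt_attention_mask, _, _ = pipe.encode_prompt(
    "a cyberpunk cat wearing sunglasses",  # placeholder prompt
    device=device,
    num_images_per_prompt=1,
)

# prompt_embeds and prompt_attention_mask match the __call__ parameters shown
# in the diff; .images assumes a standard diffusers image pipeline output.
image = pipe(
    prompt_embeds=prompt_embeds,
    prompt_attention_mask=prompt_attention_mask,
    generator=torch.Generator(device).manual_seed(0),
).images[0]
image.save("sid_sana_sample.png")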