Yinhong Liu committed on
Commit 64f514c · 1 Parent(s): c2eb006

flux pipeline

app.py CHANGED
@@ -8,7 +8,7 @@ from sid import SiDFluxPipeline, SiDSD3Pipeline, SiDSanaPipeline
8
  import torch
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
-
12
 
13
  MODEL_OPTIONS = {
14
  "SiD-Flow-SD3-medium": "YGu1998/SiD-Flow-SD3-medium",
@@ -31,13 +31,13 @@ def load_model(model_choice):
31
  model_repo_id = MODEL_OPTIONS[model_choice]
32
  time_scale = 1000.0
33
  if "Sana" in model_choice:
34
- pipe = SiDSanaPipeline.from_pretrained(model_repo_id, torch_dtype=torch.float16)
35
  if "Sprint" in model_choice:
36
  time_scale = 1.0
37
  elif "SD3" in model_choice:
38
- pipe = SiDSD3Pipeline.from_pretrained(model_repo_id, torch_dtype=torch.float16)
39
  elif "Flux" in model_choice:
40
- pipe = SiDFluxPipeline.from_pretrained(model_repo_id, torch_dtype=torch.float16)
41
  else:
42
  raise ValueError(f"Unknown model type for: {model_choice}")
43
  pipe = pipe.to(device)
 
8
  import torch
9
 
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
+ torch_dtype = torch.float16
12
 
13
  MODEL_OPTIONS = {
14
  "SiD-Flow-SD3-medium": "YGu1998/SiD-Flow-SD3-medium",
 
31
  model_repo_id = MODEL_OPTIONS[model_choice]
32
  time_scale = 1000.0
33
  if "Sana" in model_choice:
34
+ pipe = SiDSanaPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
35
  if "Sprint" in model_choice:
36
  time_scale = 1.0
37
  elif "SD3" in model_choice:
38
+ pipe = SiDSD3Pipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
39
  elif "Flux" in model_choice:
40
+ pipe = SiDFluxPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
41
  else:
42
  raise ValueError(f"Unknown model type for: {model_choice}")
43
  pipe = pipe.to(device)
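Note: a minimal usage sketch of what this commit's `load_model` path produces, runnable outside the app. The prompt, step count, and output filename are illustrative, and the repo id below is a placeholder for a Flux entry in `MODEL_OPTIONS` (only the SD3 entry is visible in this hunk); the call follows `SiDFluxPipeline.__call__` as changed in the diff below.

import torch
from sid import SiDFluxPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16  # single shared dtype introduced by this commit

# Placeholder repo id -- substitute the actual Flux entry from MODEL_OPTIONS.
pipe = SiDFluxPipeline.from_pretrained("YGu1998/SiD-Flux-model-id", torch_dtype=torch_dtype)
pipe = pipe.to(device)

# The updated __call__ (see sid/pipeline_sid_flux.py below) drops negative prompts
# and true CFG, defaults guidance_scale to 1, and adds a noise_type argument.
result = pipe(
    prompt="A cat holding a sign that says hello world",
    num_inference_steps=4,
    guidance_scale=1.0,
    noise_type="fresh",  # 'fresh', 'ddim', or 'fixed'
)
result.images[0].save("sid_flux_sample.png")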
sid/pipeline_sid_flux.py CHANGED
@@ -27,7 +27,12 @@ from transformers import (
27
  )
28
 
29
  from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
30
- from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
 
31
  from diffusers.models import AutoencoderKL, FluxTransformer2DModel
32
  from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
33
  from diffusers.utils import (
@@ -53,22 +58,6 @@ else:
53
 
54
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
55
 
56
- EXAMPLE_DOC_STRING = """
57
- Examples:
58
- ```py
59
- >>> import torch
60
- >>> from diffusers import FluxPipeline
61
-
62
- >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
63
- >>> pipe.to("cuda")
64
- >>> prompt = "A cat holding a sign that says hello world"
65
- >>> # Depending on the variant being used, the pipeline call will slightly vary.
66
- >>> # Refer to the pipeline documentation for more details.
67
- >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
68
- >>> image.save("flux.png")
69
- ```
70
- """
71
-
72
 
73
  def calculate_shift(
74
  image_seq_len,
@@ -116,9 +105,13 @@ def retrieve_timesteps(
116
  second element is the number of inference steps.
117
  """
118
  if timesteps is not None and sigmas is not None:
119
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
 
 
120
  if timesteps is not None:
121
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
 
 
122
  if not accepts_timesteps:
123
  raise ValueError(
124
  f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
@@ -128,7 +121,9 @@ def retrieve_timesteps(
128
  timesteps = scheduler.timesteps
129
  num_inference_steps = len(timesteps)
130
  elif sigmas is not None:
131
- accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
 
 
132
  if not accept_sigmas:
133
  raise ValueError(
134
  f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
@@ -176,7 +171,9 @@ class SiDFluxPipeline(
176
  [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
177
  """
178
 
179
- model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
 
 
180
  _optional_components = ["image_encoder", "feature_extractor"]
181
  _callback_tensor_inputs = ["latents", "prompt_embeds"]
182
 
@@ -205,12 +202,20 @@ class SiDFluxPipeline(
205
  image_encoder=image_encoder,
206
  feature_extractor=feature_extractor,
207
  )
208
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
 
 
 
 
209
  # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
210
  # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
211
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
 
 
212
  self.tokenizer_max_length = (
213
- self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
 
 
214
  )
215
  self.default_sample_size = 128
216
 
@@ -241,16 +246,24 @@ class SiDFluxPipeline(
241
  return_tensors="pt",
242
  )
243
  text_input_ids = text_inputs.input_ids
244
- untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
 
 
245
 
246
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
247
- removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
 
 
 
 
248
  logger.warning(
249
  "The following part of your input was truncated because `max_sequence_length` is set to "
250
  f" {max_sequence_length} tokens: {removed_text}"
251
  )
252
 
253
- prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
 
 
254
 
255
  dtype = self.text_encoder_2.dtype
256
  prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
@@ -259,7 +272,9 @@ class SiDFluxPipeline(
259
 
260
  # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
261
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
262
- prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
 
263
 
264
  return prompt_embeds
265
 
@@ -288,14 +303,22 @@ class SiDFluxPipeline(
288
  )
289
 
290
  text_input_ids = text_inputs.input_ids
291
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
292
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
293
- removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
 
294
  logger.warning(
295
  "The following part of your input was truncated because CLIP can only handle sequences up to"
296
  f" {self.tokenizer_max_length} tokens: {removed_text}"
297
  )
298
- prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
 
 
299
 
300
  # Use pooled output of CLIPTextModel
301
  prompt_embeds = prompt_embeds.pooler_output
@@ -381,7 +404,11 @@ class SiDFluxPipeline(
381
  # Retrieve the original scale by scaling back the LoRA layers
382
  unscale_lora_layers(self.text_encoder_2, lora_scale)
383
 
384
- dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
 
 
 
 
385
  text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
386
 
387
  return prompt_embeds, pooled_prompt_embeds, text_ids
@@ -405,19 +432,27 @@ class SiDFluxPipeline(
405
  if not isinstance(ip_adapter_image, list):
406
  ip_adapter_image = [ip_adapter_image]
407
 
408
- if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
 
 
 
409
  raise ValueError(
410
  f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
411
  )
412
 
413
  for single_ip_adapter_image in ip_adapter_image:
414
- single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
 
 
415
  image_embeds.append(single_image_embeds[None, :])
416
  else:
417
  if not isinstance(ip_adapter_image_embeds, list):
418
  ip_adapter_image_embeds = [ip_adapter_image_embeds]
419
 
420
- if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
 
 
 
421
  raise ValueError(
422
  f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
423
  )
@@ -427,7 +462,9 @@ class SiDFluxPipeline(
427
 
428
  ip_adapter_image_embeds = []
429
  for single_image_embeds in image_embeds:
430
- single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
 
 
431
  single_image_embeds = single_image_embeds.to(device=device)
432
  ip_adapter_image_embeds.append(single_image_embeds)
433
 
@@ -448,13 +485,17 @@ class SiDFluxPipeline(
448
  callback_on_step_end_tensor_inputs=None,
449
  max_sequence_length=None,
450
  ):
451
- if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
 
 
 
452
  logger.warning(
453
  f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
454
  )
455
 
456
  if callback_on_step_end_tensor_inputs is not None and not all(
457
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
 
458
  ):
459
  raise ValueError(
460
  f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
@@ -474,10 +515,18 @@ class SiDFluxPipeline(
474
  raise ValueError(
475
  "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
476
  )
477
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
478
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
479
- elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
480
- raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
 
481
 
482
  if negative_prompt is not None and negative_prompt_embeds is not None:
483
  raise ValueError(
@@ -500,15 +549,23 @@ class SiDFluxPipeline(
500
  )
501
 
502
  if max_sequence_length is not None and max_sequence_length > 512:
503
- raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
 
 
504
 
505
  @staticmethod
506
  def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
507
  latent_image_ids = torch.zeros(height, width, 3)
508
- latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
509
- latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
 
 
 
 
510
 
511
- latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
 
 
512
 
513
  latent_image_ids = latent_image_ids.reshape(
514
  latent_image_id_height * latent_image_id_width, latent_image_id_channels
@@ -518,9 +575,13 @@ class SiDFluxPipeline(
518
 
519
  @staticmethod
520
  def _pack_latents(latents, batch_size, num_channels_latents, height, width):
521
- latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
 
 
522
  latents = latents.permute(0, 2, 4, 1, 3, 5)
523
- latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
 
 
524
 
525
  return latents
526
 
@@ -588,7 +649,9 @@ class SiDFluxPipeline(
588
  shape = (batch_size, num_channels_latents, height, width)
589
 
590
  if latents is not None:
591
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
 
 
592
  return latents.to(device=device, dtype=dtype), latent_image_ids
593
 
594
  if isinstance(generator, list) and len(generator) != batch_size:
@@ -598,9 +661,13 @@ class SiDFluxPipeline(
598
  )
599
 
600
  latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
601
- latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
 
 
602
 
603
- latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
 
 
604
 
605
  return latents, latent_image_ids
606
 
@@ -625,135 +692,29 @@ class SiDFluxPipeline(
625
  return self._interrupt
626
 
627
  @torch.no_grad()
628
- @replace_example_docstring(EXAMPLE_DOC_STRING)
629
  def __call__(
630
  self,
631
  prompt: Union[str, List[str]] = None,
632
  prompt_2: Optional[Union[str, List[str]]] = None,
633
- negative_prompt: Union[str, List[str]] = None,
634
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
635
  true_cfg_scale: float = 1.0,
636
  height: Optional[int] = None,
637
  width: Optional[int] = None,
638
  num_inference_steps: int = 28,
639
  sigmas: Optional[List[float]] = None,
640
- guidance_scale: float = 3.5,
641
  num_images_per_prompt: Optional[int] = 1,
642
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
643
  latents: Optional[torch.FloatTensor] = None,
644
  prompt_embeds: Optional[torch.FloatTensor] = None,
645
  pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
646
- ip_adapter_image: Optional[PipelineImageInput] = None,
647
- ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
648
- negative_ip_adapter_image: Optional[PipelineImageInput] = None,
649
- negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
650
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
651
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
652
  output_type: Optional[str] = "pil",
653
  return_dict: bool = True,
654
  joint_attention_kwargs: Optional[Dict[str, Any]] = None,
655
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
656
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
657
  max_sequence_length: int = 512,
 
658
  ):
659
- r"""
660
- Function invoked when calling the pipeline for generation.
661
-
662
- Args:
663
- prompt (`str` or `List[str]`, *optional*):
664
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
665
- instead.
666
- prompt_2 (`str` or `List[str]`, *optional*):
667
- The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
668
- will be used instead.
669
- negative_prompt (`str` or `List[str]`, *optional*):
670
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
671
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
672
- not greater than `1`).
673
- negative_prompt_2 (`str` or `List[str]`, *optional*):
674
- The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
675
- `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
676
- true_cfg_scale (`float`, *optional*, defaults to 1.0):
677
- True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and
678
- `negative_prompt` is provided.
679
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
680
- The height in pixels of the generated image. This is set to 1024 by default for the best results.
681
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
682
- The width in pixels of the generated image. This is set to 1024 by default for the best results.
683
- num_inference_steps (`int`, *optional*, defaults to 50):
684
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
685
- expense of slower inference.
686
- sigmas (`List[float]`, *optional*):
687
- Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
688
- their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
689
- will be used.
690
- guidance_scale (`float`, *optional*, defaults to 3.5):
691
- Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
692
- a model to generate images more aligned with `prompt` at the expense of lower image quality.
693
-
694
- Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
695
- the [paper](https://huggingface.co/papers/2210.03142) to learn more.
696
- num_images_per_prompt (`int`, *optional*, defaults to 1):
697
- The number of images to generate per prompt.
698
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
699
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
700
- to make generation deterministic.
701
- latents (`torch.FloatTensor`, *optional*):
702
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
703
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
704
- tensor will be generated by sampling using the supplied random `generator`.
705
- prompt_embeds (`torch.FloatTensor`, *optional*):
706
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
707
- provided, text embeddings will be generated from `prompt` input argument.
708
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
709
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
710
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
711
- ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
712
- ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
713
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
714
- IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
715
- provided, embeddings are computed from the `ip_adapter_image` input argument.
716
- negative_ip_adapter_image:
717
- (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
718
- negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
719
- Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
720
- IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
721
- provided, embeddings are computed from the `ip_adapter_image` input argument.
722
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
723
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
724
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
725
- argument.
726
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
727
- Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
728
- weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
729
- input argument.
730
- output_type (`str`, *optional*, defaults to `"pil"`):
731
- The output format of the generate image. Choose between
732
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
733
- return_dict (`bool`, *optional*, defaults to `True`):
734
- Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
735
- joint_attention_kwargs (`dict`, *optional*):
736
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
737
- `self.processor` in
738
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
739
- callback_on_step_end (`Callable`, *optional*):
740
- A function that calls at the end of each denoising steps during the inference. The function is called
741
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
742
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
743
- `callback_on_step_end_tensor_inputs`.
744
- callback_on_step_end_tensor_inputs (`List`, *optional*):
745
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
746
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
747
- `._callback_tensor_inputs` attribute of your pipeline class.
748
- max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
749
-
750
- Examples:
751
-
752
- Returns:
753
- [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
754
- is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
755
- images.
756
- """
757
 
758
  height = height or self.default_sample_size * self.vae_scale_factor
759
  width = width or self.default_sample_size * self.vae_scale_factor
@@ -764,12 +725,8 @@ class SiDFluxPipeline(
764
  prompt_2,
765
  height,
766
  width,
767
- negative_prompt=negative_prompt,
768
- negative_prompt_2=negative_prompt_2,
769
  prompt_embeds=prompt_embeds,
770
- negative_prompt_embeds=negative_prompt_embeds,
771
  pooled_prompt_embeds=pooled_prompt_embeds,
772
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
773
  callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
774
  max_sequence_length=max_sequence_length,
775
  )
@@ -789,13 +746,6 @@ class SiDFluxPipeline(
789
 
790
  device = self._execution_device
791
 
792
- lora_scale = (
793
- self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
794
- )
795
- has_neg_prompt = negative_prompt is not None or (
796
- negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
797
- )
798
- do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
799
  (
800
  prompt_embeds,
801
  pooled_prompt_embeds,
@@ -808,23 +758,7 @@ class SiDFluxPipeline(
808
  device=device,
809
  num_images_per_prompt=num_images_per_prompt,
810
  max_sequence_length=max_sequence_length,
811
- lora_scale=lora_scale,
812
  )
813
- if do_true_cfg:
814
- (
815
- negative_prompt_embeds,
816
- negative_pooled_prompt_embeds,
817
- negative_text_ids,
818
- ) = self.encode_prompt(
819
- prompt=negative_prompt,
820
- prompt_2=negative_prompt_2,
821
- prompt_embeds=negative_prompt_embeds,
822
- pooled_prompt_embeds=negative_pooled_prompt_embeds,
823
- device=device,
824
- num_images_per_prompt=num_images_per_prompt,
825
- max_sequence_length=max_sequence_length,
826
- lora_scale=lora_scale,
827
- )
828
 
829
  # 4. Prepare latent variables
830
  num_channels_latents = self.transformer.config.in_channels // 4
@@ -839,147 +773,82 @@ class SiDFluxPipeline(
839
  latents,
840
  )
841
 
842
- # 5. Prepare timesteps
843
- sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
844
- if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
845
- sigmas = None
846
- image_seq_len = latents.shape[1]
847
- mu = calculate_shift(
848
- image_seq_len,
849
- self.scheduler.config.get("base_image_seq_len", 256),
850
- self.scheduler.config.get("max_image_seq_len", 4096),
851
- self.scheduler.config.get("base_shift", 0.5),
852
- self.scheduler.config.get("max_shift", 1.15),
853
- )
854
- timesteps, num_inference_steps = retrieve_timesteps(
855
- self.scheduler,
856
- num_inference_steps,
857
- device,
858
- sigmas=sigmas,
859
- mu=mu,
860
- )
861
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
862
- self._num_timesteps = len(timesteps)
863
 
864
- # handle guidance
865
- if self.transformer.config.guidance_embeds:
866
- guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
867
- guidance = guidance.expand(latents.shape[0])
868
- else:
869
- guidance = None
870
-
871
- if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
872
- negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
873
- ):
874
- negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
875
- negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
876
 
877
- elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
878
- negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
879
- ):
880
- ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
881
- ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
882
-
883
- if self.joint_attention_kwargs is None:
884
- self._joint_attention_kwargs = {}
885
-
886
- image_embeds = None
887
- negative_image_embeds = None
888
- if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
889
- image_embeds = self.prepare_ip_adapter_image_embeds(
890
- ip_adapter_image,
891
- ip_adapter_image_embeds,
892
- device,
893
- batch_size * num_images_per_prompt,
 
894
  )
895
- if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
896
- negative_image_embeds = self.prepare_ip_adapter_image_embeds(
897
- negative_ip_adapter_image,
898
- negative_ip_adapter_image_embeds,
899
- device,
900
- batch_size * num_images_per_prompt,
901
  )
902
 
903
- # 6. Denoising loop
904
- # We set the index here to remove DtoH sync, helpful especially during compilation.
905
- # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
906
- self.scheduler.set_begin_index(0)
907
- with self.progress_bar(total=num_inference_steps) as progress_bar:
908
- for i, t in enumerate(timesteps):
909
- if self.interrupt:
910
- continue
911
-
912
- self._current_timestep = t
913
- if image_embeds is not None:
914
- self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
915
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
916
- timestep = t.expand(latents.shape[0]).to(latents.dtype)
917
-
918
- with self.transformer.cache_context("cond"):
919
- noise_pred = self.transformer(
920
- hidden_states=latents,
921
- timestep=timestep / 1000,
922
- guidance=guidance,
923
- pooled_projections=pooled_prompt_embeds,
924
- encoder_hidden_states=prompt_embeds,
925
- txt_ids=text_ids,
926
- img_ids=latent_image_ids,
927
- joint_attention_kwargs=self.joint_attention_kwargs,
928
- return_dict=False,
929
- )[0]
930
-
931
- if do_true_cfg:
932
- if negative_image_embeds is not None:
933
- self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
934
-
935
- with self.transformer.cache_context("uncond"):
936
- neg_noise_pred = self.transformer(
937
- hidden_states=latents,
938
- timestep=timestep / 1000,
939
- guidance=guidance,
940
- pooled_projections=negative_pooled_prompt_embeds,
941
- encoder_hidden_states=negative_prompt_embeds,
942
- txt_ids=negative_text_ids,
943
- img_ids=latent_image_ids,
944
- joint_attention_kwargs=self.joint_attention_kwargs,
945
- return_dict=False,
946
- )[0]
947
- noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
948
-
949
- # compute the previous noisy sample x_t -> x_t-1
950
- latents_dtype = latents.dtype
951
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
952
-
953
- if latents.dtype != latents_dtype:
954
- if torch.backends.mps.is_available():
955
- # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
956
- latents = latents.to(latents_dtype)
957
-
958
- if callback_on_step_end is not None:
959
- callback_kwargs = {}
960
- for k in callback_on_step_end_tensor_inputs:
961
- callback_kwargs[k] = locals()[k]
962
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
963
-
964
- latents = callback_outputs.pop("latents", latents)
965
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
966
-
967
- # call the callback, if provided
968
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
969
- progress_bar.update()
970
-
971
- if XLA_AVAILABLE:
972
- xm.mark_step()
973
 
974
- self._current_timestep = None
 
975
 
976
- if output_type == "latent":
977
- image = latents
978
- else:
979
- latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
980
- latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
981
- image = self.vae.decode(latents, return_dict=False)[0]
982
- image = self.image_processor.postprocess(image, output_type=output_type)
983
 
984
  # Offload all models
985
  self.maybe_free_model_hooks()
@@ -987,4 +856,4 @@ class SiDFluxPipeline(
987
  if not return_dict:
988
  return (image,)
989
 
990
- return FluxPipelineOutput(images=image)
 
27
  )
28
 
29
  from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
30
+ from diffusers.loaders import (
31
+ FluxIPAdapterMixin,
32
+ FluxLoraLoaderMixin,
33
+ FromSingleFileMixin,
34
+ TextualInversionLoaderMixin,
35
+ )
36
  from diffusers.models import AutoencoderKL, FluxTransformer2DModel
37
  from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
38
  from diffusers.utils import (
 
58
 
59
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
60
 
61
 
62
  def calculate_shift(
63
  image_seq_len,
 
105
  second element is the number of inference steps.
106
  """
107
  if timesteps is not None and sigmas is not None:
108
+ raise ValueError(
109
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
110
+ )
111
  if timesteps is not None:
112
+ accepts_timesteps = "timesteps" in set(
113
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
114
+ )
115
  if not accepts_timesteps:
116
  raise ValueError(
117
  f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
 
121
  timesteps = scheduler.timesteps
122
  num_inference_steps = len(timesteps)
123
  elif sigmas is not None:
124
+ accept_sigmas = "sigmas" in set(
125
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
126
+ )
127
  if not accept_sigmas:
128
  raise ValueError(
129
  f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
 
171
  [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
172
  """
173
 
174
+ model_cpu_offload_seq = (
175
+ "text_encoder->text_encoder_2->image_encoder->transformer->vae"
176
+ )
177
  _optional_components = ["image_encoder", "feature_extractor"]
178
  _callback_tensor_inputs = ["latents", "prompt_embeds"]
179
 
 
202
  image_encoder=image_encoder,
203
  feature_extractor=feature_extractor,
204
  )
205
+ self.vae_scale_factor = (
206
+ 2 ** (len(self.vae.config.block_out_channels) - 1)
207
+ if getattr(self, "vae", None)
208
+ else 8
209
+ )
210
  # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
211
  # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
212
+ self.image_processor = VaeImageProcessor(
213
+ vae_scale_factor=self.vae_scale_factor * 2
214
+ )
215
  self.tokenizer_max_length = (
216
+ self.tokenizer.model_max_length
217
+ if hasattr(self, "tokenizer") and self.tokenizer is not None
218
+ else 77
219
  )
220
  self.default_sample_size = 128
221
 
 
246
  return_tensors="pt",
247
  )
248
  text_input_ids = text_inputs.input_ids
249
+ untruncated_ids = self.tokenizer_2(
250
+ prompt, padding="longest", return_tensors="pt"
251
+ ).input_ids
252
 
253
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
254
+ text_input_ids, untruncated_ids
255
+ ):
256
+ removed_text = self.tokenizer_2.batch_decode(
257
+ untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
258
+ )
259
  logger.warning(
260
  "The following part of your input was truncated because `max_sequence_length` is set to "
261
  f" {max_sequence_length} tokens: {removed_text}"
262
  )
263
 
264
+ prompt_embeds = self.text_encoder_2(
265
+ text_input_ids.to(device), output_hidden_states=False
266
+ )[0]
267
 
268
  dtype = self.text_encoder_2.dtype
269
  prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
272
 
273
  # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
274
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
275
+ prompt_embeds = prompt_embeds.view(
276
+ batch_size * num_images_per_prompt, seq_len, -1
277
+ )
278
 
279
  return prompt_embeds
280
 
 
303
  )
304
 
305
  text_input_ids = text_inputs.input_ids
306
+ untruncated_ids = self.tokenizer(
307
+ prompt, padding="longest", return_tensors="pt"
308
+ ).input_ids
309
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
310
+ text_input_ids, untruncated_ids
311
+ ):
312
+ removed_text = self.tokenizer.batch_decode(
313
+ untruncated_ids[:, self.tokenizer_max_length - 1 : -1]
314
+ )
315
  logger.warning(
316
  "The following part of your input was truncated because CLIP can only handle sequences up to"
317
  f" {self.tokenizer_max_length} tokens: {removed_text}"
318
  )
319
+ prompt_embeds = self.text_encoder(
320
+ text_input_ids.to(device), output_hidden_states=False
321
+ )
322
 
323
  # Use pooled output of CLIPTextModel
324
  prompt_embeds = prompt_embeds.pooler_output
 
404
  # Retrieve the original scale by scaling back the LoRA layers
405
  unscale_lora_layers(self.text_encoder_2, lora_scale)
406
 
407
+ dtype = (
408
+ self.text_encoder.dtype
409
+ if self.text_encoder is not None
410
+ else self.transformer.dtype
411
+ )
412
  text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
413
 
414
  return prompt_embeds, pooled_prompt_embeds, text_ids
 
432
  if not isinstance(ip_adapter_image, list):
433
  ip_adapter_image = [ip_adapter_image]
434
 
435
+ if (
436
+ len(ip_adapter_image)
437
+ != self.transformer.encoder_hid_proj.num_ip_adapters
438
+ ):
439
  raise ValueError(
440
  f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
441
  )
442
 
443
  for single_ip_adapter_image in ip_adapter_image:
444
+ single_image_embeds = self.encode_image(
445
+ single_ip_adapter_image, device, 1
446
+ )
447
  image_embeds.append(single_image_embeds[None, :])
448
  else:
449
  if not isinstance(ip_adapter_image_embeds, list):
450
  ip_adapter_image_embeds = [ip_adapter_image_embeds]
451
 
452
+ if (
453
+ len(ip_adapter_image_embeds)
454
+ != self.transformer.encoder_hid_proj.num_ip_adapters
455
+ ):
456
  raise ValueError(
457
  f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
458
  )
 
462
 
463
  ip_adapter_image_embeds = []
464
  for single_image_embeds in image_embeds:
465
+ single_image_embeds = torch.cat(
466
+ [single_image_embeds] * num_images_per_prompt, dim=0
467
+ )
468
  single_image_embeds = single_image_embeds.to(device=device)
469
  ip_adapter_image_embeds.append(single_image_embeds)
470
 
 
485
  callback_on_step_end_tensor_inputs=None,
486
  max_sequence_length=None,
487
  ):
488
+ if (
489
+ height % (self.vae_scale_factor * 2) != 0
490
+ or width % (self.vae_scale_factor * 2) != 0
491
+ ):
492
  logger.warning(
493
  f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
494
  )
495
 
496
  if callback_on_step_end_tensor_inputs is not None and not all(
497
+ k in self._callback_tensor_inputs
498
+ for k in callback_on_step_end_tensor_inputs
499
  ):
500
  raise ValueError(
501
  f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
 
515
  raise ValueError(
516
  "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
517
  )
518
+ elif prompt is not None and (
519
+ not isinstance(prompt, str) and not isinstance(prompt, list)
520
+ ):
521
+ raise ValueError(
522
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
523
+ )
524
+ elif prompt_2 is not None and (
525
+ not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
526
+ ):
527
+ raise ValueError(
528
+ f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
529
+ )
530
 
531
  if negative_prompt is not None and negative_prompt_embeds is not None:
532
  raise ValueError(
 
549
  )
550
 
551
  if max_sequence_length is not None and max_sequence_length > 512:
552
+ raise ValueError(
553
+ f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}"
554
+ )
555
 
556
  @staticmethod
557
  def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
558
  latent_image_ids = torch.zeros(height, width, 3)
559
+ latent_image_ids[..., 1] = (
560
+ latent_image_ids[..., 1] + torch.arange(height)[:, None]
561
+ )
562
+ latent_image_ids[..., 2] = (
563
+ latent_image_ids[..., 2] + torch.arange(width)[None, :]
564
+ )
565
 
566
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = (
567
+ latent_image_ids.shape
568
+ )
569
 
570
  latent_image_ids = latent_image_ids.reshape(
571
  latent_image_id_height * latent_image_id_width, latent_image_id_channels
 
575
 
576
  @staticmethod
577
  def _pack_latents(latents, batch_size, num_channels_latents, height, width):
578
+ latents = latents.view(
579
+ batch_size, num_channels_latents, height // 2, 2, width // 2, 2
580
+ )
581
  latents = latents.permute(0, 2, 4, 1, 3, 5)
582
+ latents = latents.reshape(
583
+ batch_size, (height // 2) * (width // 2), num_channels_latents * 4
584
+ )
585
 
586
  return latents
587
 
 
649
  shape = (batch_size, num_channels_latents, height, width)
650
 
651
  if latents is not None:
652
+ latent_image_ids = self._prepare_latent_image_ids(
653
+ batch_size, height // 2, width // 2, device, dtype
654
+ )
655
  return latents.to(device=device, dtype=dtype), latent_image_ids
656
 
657
  if isinstance(generator, list) and len(generator) != batch_size:
 
661
  )
662
 
663
  latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
664
+ latents = self._pack_latents(
665
+ latents, batch_size, num_channels_latents, height, width
666
+ )
667
 
668
+ latent_image_ids = self._prepare_latent_image_ids(
669
+ batch_size, height // 2, width // 2, device, dtype
670
+ )
671
 
672
  return latents, latent_image_ids
673
 
 
692
  return self._interrupt
693
 
694
  @torch.no_grad()
 
695
  def __call__(
696
  self,
697
  prompt: Union[str, List[str]] = None,
698
  prompt_2: Optional[Union[str, List[str]]] = None,
 
 
699
  true_cfg_scale: float = 1.0,
700
  height: Optional[int] = None,
701
  width: Optional[int] = None,
702
  num_inference_steps: int = 28,
703
  sigmas: Optional[List[float]] = None,
704
+ guidance_scale: float = 1,
705
  num_images_per_prompt: Optional[int] = 1,
706
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
707
  latents: Optional[torch.FloatTensor] = None,
708
  prompt_embeds: Optional[torch.FloatTensor] = None,
709
  pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
 
710
  output_type: Optional[str] = "pil",
711
  return_dict: bool = True,
712
  joint_attention_kwargs: Optional[Dict[str, Any]] = None,
713
  callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
714
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
715
  max_sequence_length: int = 512,
716
+ noise_type: str = "fresh", # 'fresh', 'ddim', 'fixed'
717
  ):
 
718
 
719
  height = height or self.default_sample_size * self.vae_scale_factor
720
  width = width or self.default_sample_size * self.vae_scale_factor
 
725
  prompt_2,
726
  height,
727
  width,
 
 
728
  prompt_embeds=prompt_embeds,
 
729
  pooled_prompt_embeds=pooled_prompt_embeds,
 
730
  callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
731
  max_sequence_length=max_sequence_length,
732
  )
 
746
 
747
  device = self._execution_device
748
 
749
  (
750
  prompt_embeds,
751
  pooled_prompt_embeds,
 
758
  device=device,
759
  num_images_per_prompt=num_images_per_prompt,
760
  max_sequence_length=max_sequence_length,
 
761
  )
 
762
 
763
  # 4. Prepare latent variables
764
  num_channels_latents = self.transformer.config.in_channels // 4
 
773
  latents,
774
  )
775
 
776
+ # Denoising loop
777
+ D_x = torch.zeros_like(latents).to(latents.device)
778
+ initial_latents = latents.clone() if noise_type == "fixed" else None
 
779
 
780
+ for i in range(num_inference_steps):
 
781
 
782
+ if noise_type == "fresh":
783
+ noise = (
784
+ latents if i == 0 else torch.randn_like(latents).to(latents.device)
785
+ )
786
+ elif noise_type == "ddim":
787
+ noise = (
788
+ latents if i == 0 else ((latents - (1.0 - t) * D_x) / t).detach()
789
+ )
790
+ elif noise_type == "fixed":
791
+ noise = initial_latents # Use the initial, unmodified latents
792
+ else:
793
+ raise ValueError(f"Unknown noise_type: {noise_type}")
794
+
795
+ with torch.no_grad():
796
+ # Compute timestep t for current denoising step, normalized to [0, 1]
797
+ scalar_t = 999.0 * (1.0 - float(i) / float(num_inference_steps - 1))
798
+ t_val = scalar_t / 999.0
799
+ t = torch.full(
800
+ (latents.shape[0],),
801
+ t_val,
802
+ device=latents.device,
803
+ dtype=latents.dtype,
804
+ )
805
+ if t.numel() > 1:
806
+ t = t.view(-1, 1, 1, 1)
807
+
808
+ latents = (1.0 - t) * D_x + t * noise
809
+ latent_image_ids = self._prepare_latent_image_ids(
810
+ latents.shape[0],
811
+ latents.shape[2] // 2,
812
+ latents.shape[3] // 2,
813
+ latents.device,
814
+ latents.dtype,
815
  )
816
+ packed_latents = self._pack_latents(
817
+ latents,
818
+ batch_size=latents.shape[0],
819
+ num_channels_latents=latents.shape[1],
820
+ height=latents.shape[2],
821
+ width=latents.shape[3],
822
  )
823
 
824
+ guidance = torch.tensor([guidance_scale], device=device)
825
+ guidance = guidance.expand(latents.shape[0])
 
826
 
827
+ flow_pred = self.transformer(
828
+ hidden_states=packed_latents,
829
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
830
+ timestep=t.view(-1), # timesteps / 1000.0,
831
+ guidance=guidance,
832
+ pooled_projections=pooled_prompt_embeds,
833
+ encoder_hidden_states=prompt_embeds,
834
+ txt_ids=text_ids,
835
+ img_ids=latent_image_ids,
836
+ return_dict=False,
837
+ )[0]
838
+
839
+ flow_pred = self._unpack_latents(
840
+ flow_pred,
841
+ height=height * self.vae_scale_factor,
842
+ width=width * self.vae_scale_factor,
843
+ vae_scale_factor=self.vae_scale_factor,
844
+ )
845
+ D_x = latents - t.view(-1, 1, 1, 1) * flow_pred
846
 
847
+ latents = (
848
+ latents / self.vae.config.scaling_factor
849
+ ) + self.vae.config.shift_factor
850
+ image = self.vae.decode(latents, return_dict=False)[0]
851
+ image = self.image_processor.postprocess(image, output_type=output_type)
 
 
852
 
853
  # Offload all models
854
  self.maybe_free_model_hooks()
 
856
  if not return_dict:
857
  return (image,)
858
 
859
+ return SiDPipelineOutput(images=image)
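To summarize the new sampling loop introduced above: the pipeline keeps a running clean-image estimate D_x, re-noises it at a decreasing time t (t goes from 1 to 0 over num_inference_steps), and treats the transformer output as a flow/velocity, so D_x = x_t - t * v. The sketch below is a toy abstraction of that recurrence, not the pipeline itself: `velocity_model` is a hypothetical stand-in for the packed-latent transformer call, and the latent packing/unpacking, guidance tensor, and VAE decode are omitted.

import torch

def sid_sample_sketch(velocity_model, shape, num_steps=4, noise_type="fresh",
                      generator=None, device="cpu", dtype=torch.float32):
    # Toy version of the loop above; velocity_model(x_t, t) must return a tensor
    # of the same shape as x_t. Requires num_steps >= 2.
    D_x = torch.zeros(shape, device=device, dtype=dtype)   # running clean estimate
    init = torch.randn(shape, generator=generator, device=device, dtype=dtype)
    x_t, t = init, 1.0
    for i in range(num_steps):
        if noise_type == "fresh":
            # redraw Gaussian noise every step (the first step uses the initial draw)
            noise = init if i == 0 else torch.randn(shape, generator=generator,
                                                    device=device, dtype=dtype)
        elif noise_type == "ddim":
            # reuse the noise implied by the previous (x_t, D_x, t) triple
            noise = init if i == 0 else (x_t - (1.0 - t) * D_x) / t
        elif noise_type == "fixed":
            noise = init                       # always re-noise with the first draw
        else:
            raise ValueError(f"Unknown noise_type: {noise_type}")
        t = 1.0 - i / (num_steps - 1)          # t runs 1 -> 0, as in the pipeline
        x_t = (1.0 - t) * D_x + t * noise      # re-noise the current estimate
        v = velocity_model(x_t, t)             # stand-in for the transformer call
        D_x = x_t - t * v                      # flow-matching clean estimate
    return D_x

With noise_type="fixed" every step re-noises with the same initial draw, "fresh" redraws Gaussian noise at each step, and "ddim" reuses the noise implied by the previous iterate, mirroring the three branches in the loop above.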
sid/pipeline_sid_sana.py CHANGED
@@ -159,9 +159,13 @@ def retrieve_timesteps(
159
  second element is the number of inference steps.
160
  """
161
  if timesteps is not None and sigmas is not None:
162
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
 
 
163
  if timesteps is not None:
164
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
 
 
165
  if not accepts_timesteps:
166
  raise ValueError(
167
  f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
@@ -171,7 +175,9 @@ def retrieve_timesteps(
171
  timesteps = scheduler.timesteps
172
  num_inference_steps = len(timesteps)
173
  elif sigmas is not None:
174
- accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
 
 
175
  if not accept_sigmas:
176
  raise ValueError(
177
  f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
@@ -209,7 +215,11 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
209
  super().__init__()
210
 
211
  self.register_modules(
212
- tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
 
 
 
 
213
  )
214
 
215
  self.vae_scale_factor = (
@@ -217,7 +227,9 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
217
  if hasattr(self, "vae") and self.vae is not None
218
  else 32
219
  )
220
- self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
 
221
 
222
  def enable_vae_slicing(self):
223
  r"""
@@ -301,7 +313,9 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
301
  prompt_attention_mask = text_inputs.attention_mask
302
  prompt_attention_mask = prompt_attention_mask.to(device)
303
 
304
- prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
 
 
305
  prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
306
 
307
  return prompt_embeds, prompt_attention_mask
@@ -398,33 +412,51 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
398
  bs_embed, seq_len, _ = prompt_embeds.shape
399
  # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
400
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
401
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
 
402
  prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
403
  prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
404
 
405
  # get unconditional embeddings for classifier free guidance
406
  if do_classifier_free_guidance and negative_prompt_embeds is None:
407
- negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
408
- negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
409
- prompt=negative_prompt,
410
- device=device,
411
- dtype=dtype,
412
- clean_caption=clean_caption,
413
- max_sequence_length=max_sequence_length,
414
- complex_human_instruction=False,
 
415
  )
416
 
417
  if do_classifier_free_guidance:
418
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
419
  seq_len = negative_prompt_embeds.shape[1]
420
 
421
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
 
 
422
 
423
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
424
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
 
 
 
425
 
426
- negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
427
- negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
 
 
 
 
428
  else:
429
  negative_prompt_embeds = None
430
  negative_prompt_attention_mask = None
@@ -434,7 +466,12 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
434
  # Retrieve the original scale by scaling back the LoRA layers
435
  unscale_lora_layers(self.text_encoder, lora_scale)
436
 
437
- return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
 
438
 
439
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
440
  def prepare_extra_step_kwargs(self, generator, eta):
@@ -443,13 +480,17 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
443
  # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
444
  # and should be between [0, 1]
445
 
446
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
 
 
447
  extra_step_kwargs = {}
448
  if accepts_eta:
449
  extra_step_kwargs["eta"] = eta
450
 
451
  # check if the scheduler accepts generator
452
- accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
 
 
453
  if accepts_generator:
454
  extra_step_kwargs["generator"] = generator
455
  return extra_step_kwargs
@@ -467,10 +508,13 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
467
  negative_prompt_attention_mask=None,
468
  ):
469
  if height % 32 != 0 or width % 32 != 0:
470
- raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
 
 
471
 
472
  if callback_on_step_end_tensor_inputs is not None and not all(
473
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
 
474
  ):
475
  raise ValueError(
476
  f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
@@ -485,8 +529,12 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
485
  raise ValueError(
486
  "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
487
  )
488
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
489
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
 
 
 
490
 
491
  if prompt is not None and negative_prompt_embeds is not None:
492
  raise ValueError(
@@ -501,10 +549,17 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
501
  )
502
 
503
  if prompt_embeds is not None and prompt_attention_mask is None:
504
- raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
 
 
505
 
506
- if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
507
- raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
 
508
 
509
  if prompt_embeds is not None and negative_prompt_embeds is not None:
510
  if prompt_embeds.shape != negative_prompt_embeds.shape:
@@ -523,12 +578,16 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
523
  # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
524
  def _text_preprocessing(self, text, clean_caption=False):
525
  if clean_caption and not is_bs4_available():
526
- logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
 
 
527
  logger.warning("Setting `clean_caption` to False...")
528
  clean_caption = False
529
 
530
  if clean_caption and not is_ftfy_available():
531
- logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
 
 
532
  logger.warning("Setting `clean_caption` to False...")
533
  clean_caption = False
534
 
@@ -616,13 +675,17 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
616
  # "123456.."
617
  caption = re.sub(r"\b\d{6,}\b", "", caption)
618
  # filenames:
619
- caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
 
 
620
 
621
  #
622
  caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
623
  caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
624
 
625
- caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
 
 
626
  caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
627
 
628
  # this-is-my-cute-cat / this_is_my_cute_cat
@@ -640,10 +703,14 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
640
  caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
641
  caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
642
  caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
643
- caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
 
 
644
  caption = re.sub(r"\bpage\s+\d+\b", "", caption)
645
 
646
- caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
 
 
647
 
648
  caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
649
 
@@ -660,7 +727,17 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
660
 
661
  return caption.strip()
662
 
663
- def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
 
664
  if latents is not None:
665
  return latents.to(device=device, dtype=dtype)
666
 
@@ -733,8 +810,10 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
733
  else:
734
  raise ValueError("Invalid sample size")
735
  orig_height, orig_width = height, width
736
- height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
737
-
 
 
738
  if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
739
  callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
740
 
@@ -764,7 +843,8 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
764
  (
765
  prompt_embeds,
766
  prompt_attention_mask,
767
- _, _,
 
768
  ) = self.encode_prompt(
769
  prompt,
770
  prompt_embeds=prompt_embeds,
@@ -840,7 +920,9 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
840
  return_dict=False,
841
  )[0]
842
  if use_resolution_binning:
843
- image = self.image_processor.resize_and_crop_tensor(image, orig_height, orig_width)
 
 
844
  image = self.image_processor.postprocess(image, output_type=output_type)
845
  # Offload all models
846
  self.maybe_free_model_hooks()
 
159
  second element is the number of inference steps.
160
  """
161
  if timesteps is not None and sigmas is not None:
162
+ raise ValueError(
163
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
164
+ )
165
  if timesteps is not None:
166
+ accepts_timesteps = "timesteps" in set(
167
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
168
+ )
169
  if not accepts_timesteps:
170
  raise ValueError(
171
  f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
 
175
  timesteps = scheduler.timesteps
176
  num_inference_steps = len(timesteps)
177
  elif sigmas is not None:
178
+ accept_sigmas = "sigmas" in set(
179
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
180
+ )
181
  if not accept_sigmas:
182
  raise ValueError(
183
  f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
 
215
  super().__init__()
216
 
217
  self.register_modules(
218
+ tokenizer=tokenizer,
219
+ text_encoder=text_encoder,
220
+ vae=vae,
221
+ transformer=transformer,
222
+ scheduler=scheduler,
223
  )
224
 
225
  self.vae_scale_factor = (
 
227
  if hasattr(self, "vae") and self.vae is not None
228
  else 32
229
  )
230
+ self.image_processor = PixArtImageProcessor(
231
+ vae_scale_factor=self.vae_scale_factor
232
+ )
233
 
234
  def enable_vae_slicing(self):
235
  r"""
 
313
  prompt_attention_mask = text_inputs.attention_mask
314
  prompt_attention_mask = prompt_attention_mask.to(device)
315
 
316
+ prompt_embeds = self.text_encoder(
317
+ text_input_ids.to(device), attention_mask=prompt_attention_mask
318
+ )
319
  prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
320
 
321
  return prompt_embeds, prompt_attention_mask
 
412
  bs_embed, seq_len, _ = prompt_embeds.shape
413
  # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
414
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
415
+ prompt_embeds = prompt_embeds.view(
416
+ bs_embed * num_images_per_prompt, seq_len, -1
417
+ )
418
  prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
419
  prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
420
 
421
  # get unconditional embeddings for classifier free guidance
422
  if do_classifier_free_guidance and negative_prompt_embeds is None:
423
+ negative_prompt = (
424
+ [negative_prompt] * batch_size
425
+ if isinstance(negative_prompt, str)
426
+ else negative_prompt
427
+ )
428
+ negative_prompt_embeds, negative_prompt_attention_mask = (
429
+ self._get_gemma_prompt_embeds(
430
+ prompt=negative_prompt,
431
+ device=device,
432
+ dtype=dtype,
433
+ clean_caption=clean_caption,
434
+ max_sequence_length=max_sequence_length,
435
+ complex_human_instruction=False,
436
+ )
437
  )
438
 
439
  if do_classifier_free_guidance:
440
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
441
  seq_len = negative_prompt_embeds.shape[1]
442
 
443
+ negative_prompt_embeds = negative_prompt_embeds.to(
444
+ dtype=dtype, device=device
445
+ )
446
 
447
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
448
+ 1, num_images_per_prompt, 1
449
+ )
450
+ negative_prompt_embeds = negative_prompt_embeds.view(
451
+ batch_size * num_images_per_prompt, seq_len, -1
452
+ )
453
 
454
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(
455
+ bs_embed, -1
456
+ )
457
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(
458
+ num_images_per_prompt, 1
459
+ )
460
  else:
461
  negative_prompt_embeds = None
462
  negative_prompt_attention_mask = None
 
466
  # Retrieve the original scale by scaling back the LoRA layers
467
  unscale_lora_layers(self.text_encoder, lora_scale)
468
 
469
+ return (
470
+ prompt_embeds,
471
+ prompt_attention_mask,
472
+ negative_prompt_embeds,
473
+ negative_prompt_attention_mask,
474
+ )
475
 
476
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
477
  def prepare_extra_step_kwargs(self, generator, eta):
 
480
  # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
481
  # and should be between [0, 1]
482
 
483
+ accepts_eta = "eta" in set(
484
+ inspect.signature(self.scheduler.step).parameters.keys()
485
+ )
486
  extra_step_kwargs = {}
487
  if accepts_eta:
488
  extra_step_kwargs["eta"] = eta
489
 
490
  # check if the scheduler accepts generator
491
+ accepts_generator = "generator" in set(
492
+ inspect.signature(self.scheduler.step).parameters.keys()
493
+ )
494
  if accepts_generator:
495
  extra_step_kwargs["generator"] = generator
496
  return extra_step_kwargs
 
508
  negative_prompt_attention_mask=None,
509
  ):
510
  if height % 32 != 0 or width % 32 != 0:
511
+ raise ValueError(
512
+ f"`height` and `width` have to be divisible by 32 but are {height} and {width}."
513
+ )
514
 
515
  if callback_on_step_end_tensor_inputs is not None and not all(
516
+ k in self._callback_tensor_inputs
517
+ for k in callback_on_step_end_tensor_inputs
518
  ):
519
  raise ValueError(
520
  f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
 
529
  raise ValueError(
530
  "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
531
  )
532
+ elif prompt is not None and (
533
+ not isinstance(prompt, str) and not isinstance(prompt, list)
534
+ ):
535
+ raise ValueError(
536
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
537
+ )
538
 
539
  if prompt is not None and negative_prompt_embeds is not None:
540
  raise ValueError(
 
549
  )
550
 
551
  if prompt_embeds is not None and prompt_attention_mask is None:
552
+ raise ValueError(
553
+ "Must provide `prompt_attention_mask` when specifying `prompt_embeds`."
554
+ )
555
 
556
+ if (
557
+ negative_prompt_embeds is not None
558
+ and negative_prompt_attention_mask is None
559
+ ):
560
+ raise ValueError(
561
+ "Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`."
562
+ )
563
 
564
  if prompt_embeds is not None and negative_prompt_embeds is not None:
565
  if prompt_embeds.shape != negative_prompt_embeds.shape:
 
578
  # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
579
  def _text_preprocessing(self, text, clean_caption=False):
580
  if clean_caption and not is_bs4_available():
581
+ logger.warning(
582
+ BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")
583
+ )
584
  logger.warning("Setting `clean_caption` to False...")
585
  clean_caption = False
586
 
587
  if clean_caption and not is_ftfy_available():
588
+ logger.warning(
589
+ BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")
590
+ )
591
  logger.warning("Setting `clean_caption` to False...")
592
  clean_caption = False
593
 
 
675
  # "123456.."
676
  caption = re.sub(r"\b\d{6,}\b", "", caption)
677
  # filenames:
678
+ caption = re.sub(
679
+ r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption
680
+ )
681
 
682
  #
683
  caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
684
  caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
685
 
686
+ caption = re.sub(
687
+ self.bad_punct_regex, r" ", caption
688
+ ) # ***AUSVERKAUFT***, #AUSVERKAUFT
689
  caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
690
 
691
  # this-is-my-cute-cat / this_is_my_cute_cat
 
703
  caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
704
  caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
705
  caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
706
+ caption = re.sub(
707
+ r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption
708
+ )
709
  caption = re.sub(r"\bpage\s+\d+\b", "", caption)
710
 
711
+ caption = re.sub(
712
+ r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption
713
+ ) # j2d1a2a...
714
 
715
  caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
716
 
 
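Note on the caption-cleaning regexes reformatted above: when clean_caption is enabled they strip filenames, "free download"/"click for ..." boilerplate, and pixel dimensions from the caption before encoding. A minimal standalone sketch using only the four patterns visible in these hunks (the sample caption is made up; the full _clean_caption method applies many more substitutions):

import re

caption = "cute cat photo.jpg free download 1024x768 click for more"
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)  # drops "photo.jpg"
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)                        # drops "free download"
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)                        # drops "click for more"
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)                       # drops "1024x768"
print(" ".join(caption.split()))  # -> "cute cat"
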
727
 
728
  return caption.strip()
729
 
730
+ def prepare_latents(
731
+ self,
732
+ batch_size,
733
+ num_channels_latents,
734
+ height,
735
+ width,
736
+ dtype,
737
+ device,
738
+ generator,
739
+ latents=None,
740
+ ):
741
  if latents is not None:
742
  return latents.to(device=device, dtype=dtype)
743
 
 
810
  else:
811
  raise ValueError("Invalid sample size")
812
  orig_height, orig_width = height, width
813
+ height, width = self.image_processor.classify_height_width_bin(
814
+ height, width, ratios=aspect_ratio_bin
815
+ )
816
+
817
  if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
818
  callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
819
 
 
843
  (
844
  prompt_embeds,
845
  prompt_attention_mask,
846
+ _,
847
+ _,
848
  ) = self.encode_prompt(
849
  prompt,
850
  prompt_embeds=prompt_embeds,
 
920
  return_dict=False,
921
  )[0]
922
  if use_resolution_binning:
923
+ image = self.image_processor.resize_and_crop_tensor(
924
+ image, orig_height, orig_width
925
+ )
926
  image = self.image_processor.postprocess(image, output_type=output_type)
927
  # Offload all models
928
  self.maybe_free_model_hooks()
sid/pipeline_sid_sd3.py CHANGED
@@ -749,16 +749,16 @@ class SiDSD3Pipeline(
749
  # Initialize D_x
750
  D_x = torch.zeros_like(latents).to(latents.device)
751
  # Use fixed noise for now (can be extended as needed)
752
- initial_latents = latents.clone()
753
  for i in range(num_inference_steps):
754
  if noise_type == "fresh":
755
  noise = (
756
  latents if i == 0 else torch.randn_like(latents).to(latents.device)
757
  )
758
  elif noise_type == "ddim":
759
- noise = (
760
- latents if i == 0 else ((latents - (1.0 - t) * D_x) / t).detach()
761
- )
762
  elif noise_type == "fixed":
763
  noise = initial_latents # Use the initial, unmodified latents
764
  else:
 
749
  # Initialize D_x
750
  D_x = torch.zeros_like(latents).to(latents.device)
751
  # Use fixed noise for now (can be extended as needed)
752
+ initial_latents = latents.clone() if noise_type == "fixed" else None
753
  for i in range(num_inference_steps):
754
  if noise_type == "fresh":
755
  noise = (
756
  latents if i == 0 else torch.randn_like(latents).to(latents.device)
757
  )
758
  elif noise_type == "ddim":
759
+ noise = (
760
+ latents if i == 0 else ((latents - (1.0 - t) * D_x) / t).detach()
761
+ )
762
  elif noise_type == "fixed":
763
  noise = initial_latents # Use the initial, unmodified latents
764
  else: