Yinhong Liu committed on
Commit 3dfb2f9 · 1 Parent(s): e5487ed

sana pipeline

Files changed (3)
  1. app.py +9 -9
  2. sid/pipeline_sid_sana.py +83 -242
  3. sid/pipeline_sid_sd3.py +36 -17
app.py CHANGED
@@ -9,10 +9,6 @@ import torch
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-if torch.cuda.is_available():
-    torch_dtype = torch.float16
-else:
-    torch_dtype = torch.float32
 
 MODEL_OPTIONS = {
     "SiD-Flow-SD3-medium": "YGu1998/SiD-Flow-SD3-medium",
@@ -33,16 +29,19 @@ MODEL_OPTIONS = {
 
 def load_model(model_choice):
     model_repo_id = MODEL_OPTIONS[model_choice]
+    time_scale = 1000.0
     if "Sana" in model_choice:
-        pipe = SiDSanaPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
+        pipe = SiDSanaPipeline.from_pretrained(model_repo_id, torch_dtype=torch.float16)
+        if "Sprint" in model_choice:
+            time_scale = 1.0
     elif "SD3" in model_choice:
-        pipe = SiDSD3Pipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
+        pipe = SiDSD3Pipeline.from_pretrained(model_repo_id, torch_dtype=torch.float16)
     elif "Flux" in model_choice:
-        pipe = SiDFluxPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
+        pipe = SiDFluxPipeline.from_pretrained(model_repo_id, torch_dtype=torch.float16)
     else:
         raise ValueError(f"Unknown model type for: {model_choice}")
     pipe = pipe.to(device)
-    return pipe
+    return pipe, time_scale
 
 
 MAX_SEED = np.iinfo(np.int32).max
@@ -65,7 +64,7 @@ def infer(
 
     generator = torch.Generator().manual_seed(seed)
 
-    pipe = load_model(model_choice)
+    pipe, time_scale = load_model(model_choice)
 
     image = pipe(
         prompt=prompt,
@@ -74,6 +73,7 @@
         width=width,
         height=height,
         generator=generator,
+        time_scale=time_scale,
     ).images[0]
 
     return image, seed
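
The app now threads a per-model time_scale from load_model into the pipeline call. A minimal sketch of that selection logic and what it implies for the timestep the transformer sees (only "SiD-Flow-SD3-medium" is visible in this hunk; the Sana Sprint menu key is assumed to contain "Sprint"):

# Sketch only: mirrors the updated app.py wiring, not a verbatim excerpt.
def pick_time_scale(model_choice: str) -> float:
    # Sprint checkpoints get time_scale=1.0, everything else keeps 1000.0,
    # so the pipelines pass either t in [0, 1] or 1000 * t to the transformer.
    if "Sana" in model_choice and "Sprint" in model_choice:
        return 1.0
    return 1000.0

assert pick_time_scale("SiD-Flow-SD3-medium") == 1000.0
assert pick_time_scale("SiD-Sana-Sprint") == 1.0  # hypothetical menu key
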
sid/pipeline_sid_sana.py CHANGED
@@ -700,141 +700,27 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         return self._interrupt
 
     @torch.no_grad()
-    @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        negative_prompt: str = "",
-        num_inference_steps: int = 20,
-        timesteps: List[int] = None,
-        sigmas: List[float] = None,
-        guidance_scale: float = 4.5,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 28,
+        guidance_scale: float = 1.0,
         num_images_per_prompt: Optional[int] = 1,
-        height: int = 1024,
-        width: int = 1024,
-        eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.Tensor] = None,
-        prompt_embeds: Optional[torch.Tensor] = None,
-        prompt_attention_mask: Optional[torch.Tensor] = None,
-        negative_prompt_embeds: Optional[torch.Tensor] = None,
-        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        clean_caption: bool = False,
-        use_resolution_binning: bool = True,
-        attention_kwargs: Optional[Dict[str, Any]] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        max_sequence_length: int = 300,
-        complex_human_instruction: List[str] = [
-            "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
-            "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
-            "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
-            "Here are examples of how to transform or refine prompts:",
-            "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
-            "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
-            "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
-            "User Prompt: ",
-        ],
-    ) -> Union[SiDPipelineOutput, Tuple]:
-        """
-        Function invoked when calling the pipeline for generation.
-
-        Args:
-            prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
-                instead.
-            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
-                less than `1`).
-            num_inference_steps (`int`, *optional*, defaults to 20):
-                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
-                expense of slower inference.
-            timesteps (`List[int]`, *optional*):
-                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
-                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
-                passed will be used. Must be in descending order.
-            sigmas (`List[float]`, *optional*):
-                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
-                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
-                will be used.
-            guidance_scale (`float`, *optional*, defaults to 4.5):
-                Guidance scale as defined in [Classifier-Free Diffusion
-                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
-                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
-                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
-                the text `prompt`, usually at the expense of lower image quality.
-            num_images_per_prompt (`int`, *optional*, defaults to 1):
-                The number of images to generate per prompt.
-            height (`int`, *optional*, defaults to self.unet.config.sample_size):
-                The height in pixels of the generated image.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size):
-                The width in pixels of the generated image.
-            eta (`float`, *optional*, defaults to 0.0):
-                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
-                applies to [`schedulers.DDIMScheduler`], will be ignored for others.
-            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
-                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
-                to make generation deterministic.
-            latents (`torch.Tensor`, *optional*):
-                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
-                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
-                tensor will ge generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-            prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
-            negative_prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
-                provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
-            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
-                Pre-generated attention mask for negative text embeddings.
-            output_type (`str`, *optional*, defaults to `"pil"`):
-                The output format of the generate image. Choose between
-                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
-            attention_kwargs:
-                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
-                `self.processor` in
-                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
-            clean_caption (`bool`, *optional*, defaults to `True`):
-                Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
-                be installed. If the dependencies are not installed, the embeddings will be created from the raw
-                prompt.
-            use_resolution_binning (`bool` defaults to `True`):
-                If set to `True`, the requested height and width are first mapped to the closest resolutions using
-                `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
-                the requested resolution. Useful for generating non-square images.
-            callback_on_step_end (`Callable`, *optional*):
-                A function that calls at the end of each denoising steps during the inference. The function is called
-                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
-                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
-                `callback_on_step_end_tensor_inputs`.
-            callback_on_step_end_tensor_inputs (`List`, *optional*):
-                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
-                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
-                `._callback_tensor_inputs` attribute of your pipeline class.
-            max_sequence_length (`int` defaults to `300`):
-                Maximum sequence length to use with the `prompt`.
-            complex_human_instruction (`List[str]`, *optional*):
-                Instructions for complex human attention:
-                https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
-
-        Examples:
-
-        Returns:
-            [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned,
-                otherwise a `tuple` is returned where the first element is a list with the generated images
-        """
-
-        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
-            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
-
-        # 1. Check inputs. Raise error if not correct
+        max_sequence_length: int = 256,
+        noise_type: str = "fresh", # 'fresh', 'ddim', 'fixed'
+        time_scale: float = 1000.0,
+        use_resolution_binning: bool = True,
+    ):
         if use_resolution_binning:
             if self.transformer.config.sample_size == 128:
                 aspect_ratio_bin = ASPECT_RATIO_4096_BIN
@@ -848,24 +734,24 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
             raise ValueError("Invalid sample size")
         orig_height, orig_width = height, width
         height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
+        # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
             height,
             width,
-            callback_on_step_end_tensor_inputs,
-            negative_prompt,
-            prompt_embeds,
-            negative_prompt_embeds,
-            prompt_attention_mask,
-            negative_prompt_attention_mask,
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
         )
 
         self._guidance_scale = guidance_scale
-        self._attention_kwargs = attention_kwargs
         self._interrupt = False
 
-        # 2. Default height and width to transformer
+        # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
@@ -874,134 +760,89 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
             batch_size = prompt_embeds.shape[0]
 
         device = self._execution_device
-        lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
 
-        # 3. Encode input prompt
         (
             prompt_embeds,
-            prompt_attention_mask,
-            negative_prompt_embeds,
-            negative_prompt_attention_mask,
+            pooled_prompt_embeds,
+            _, _,
         ) = self.encode_prompt(
            prompt,
-            self.do_classifier_free_guidance,
-            negative_prompt=negative_prompt,
-            num_images_per_prompt=num_images_per_prompt,
-            device=device,
            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            prompt_attention_mask=prompt_attention_mask,
-            negative_prompt_attention_mask=negative_prompt_attention_mask,
-            clean_caption=clean_caption,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
-            complex_human_instruction=complex_human_instruction,
-            lora_scale=lora_scale,
        )
-        if self.do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
-
-        # 4. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler, num_inference_steps, device, timesteps, sigmas
-        )
-
-        # 5. Prepare latents.
-        latent_channels = self.transformer.config.in_channels
+        # 3. Prepare latents
+        num_channels_latents = self.transformer.config.in_channels
         latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
-            latent_channels,
+            num_channels_latents,
            height,
            width,
-            torch.float32,
+            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )
 
-        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
-        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
-
-        # 7. Denoising loop
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-        self._num_timesteps = len(timesteps)
-
-        transformer_dtype = self.transformer.dtype
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                if self.interrupt:
-                    continue
-
-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-
-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latent_model_input.shape[0])
-                timestep = timestep * self.transformer.config.timestep_scale
-
-                # predict noise model_output
-                noise_pred = self.transformer(
-                    latent_model_input.to(dtype=transformer_dtype),
-                    encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
-                    encoder_attention_mask=prompt_attention_mask,
-                    timestep=timestep,
-                    return_dict=False,
-                    attention_kwargs=self.attention_kwargs,
-                )[0]
-                noise_pred = noise_pred.float()
-
-                # perform guidance
-                if self.do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                # learned sigma
-                if self.transformer.config.out_channels // 2 == latent_channels:
-                    noise_pred = noise_pred.chunk(2, dim=1)[0]
-
-                # compute previous image: x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
-
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-
-                if XLA_AVAILABLE:
-                    xm.mark_step()
-
-        if output_type == "latent":
-            image = latents
-        else:
-            latents = latents.to(self.vae.dtype)
-            torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
-            oom_error = (
-                torch.OutOfMemoryError
-                if is_torch_version(">=", "2.5.0")
-                else torch_accelerator_module.OutOfMemoryError
-            )
-            try:
-                image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
-            except oom_error as e:
-                warnings.warn(
-                    f"{e}. \n"
-                    f"Try to use VAE tiling for large images. For example: \n"
-                    f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
+        # 4. SiD sampling loop
+        # Initialize D_x
+        D_x = torch.zeros_like(latents).to(latents.device)
+        # Use fixed noise for now (can be extended as needed)
+        initial_latents = latents.clone()
+        for i in range(num_inference_steps):
+            if noise_type == "fresh":
+                noise = (
+                    latents if i == 0 else torch.randn_like(latents).to(latents.device)
+                )
+            elif noise_type == "ddim":
+                noise = (
+                    latents if i == 0 else ((latents - (1.0 - t) * D_x) / t).detach()
                )
-        if use_resolution_binning:
-            image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
+            elif noise_type == "fixed":
+                noise = initial_latents # Use the initial, unmodified latents
+            else:
+                raise ValueError(f"Unknown noise_type: {noise_type}")
 
-        if not output_type == "latent":
-            image = self.image_processor.postprocess(image, output_type=output_type)
+            # Compute t value, normalized to [0, 1]
+            init_timesteps = 999
+            scalar_t = float(init_timesteps) * (
+                1.0 - float(i) / float(num_inference_steps)
+            )
+            t_val = scalar_t / 999.0
+            t = torch.full(
+                (latents.shape[0],), t_val, device=latents.device, dtype=latents.dtype
+            )
+            t_flattern = t.flatten()
+            if t.numel() > 1:
+                t = t.view(-1, 1, 1, 1)
+
+            latents = (1.0 - t) * D_x + t * noise
+            latent_model_input = latents
+
+            flow_pred = self.transformer(
+                hidden_states=latent_model_input,
+                encoder_hidden_states=prompt_embeds,
+                # encoder_attention_mask=prompt_attention_mask,
+                pooled_projections=pooled_prompt_embeds,
+                timestep=time_scale * t_flattern,
+                return_dict=False,
+            )[0]
+            D_x = latents - (
+                t * flow_pred
+                if torch.numel(t) == 1
+                else t.view(-1, 1, 1, 1) * flow_pred
+            )
 
+        # 5. Decode latent to image
+        image = self.vae.decode(
+            (D_x / self.vae.config.scaling_factor),
+            return_dict=False,
+        )[0]
+        if use_resolution_binning:
+            image = self.image_processor.resize_and_crop_tensor(image, orig_height, orig_width)
+        image = self.image_processor.postprocess(image, output_type=output_type)
         # Offload all models
         self.maybe_free_model_hooks()
 
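
The new body replaces the CFG/scheduler loop with a few-step SiD-style sampler: each step re-noises the current clean estimate D_x at a decreasing t, queries the flow model once, and recovers D_x = x_t - t * v(x_t, t). A self-contained sketch of that structure, with a toy velocity function standing in for the Sana transformer (illustrative only, not the repository's model):

# Minimal sketch of the SiD sampling loop introduced above (toy model, no diffusers).
import torch

def toy_flow(x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    # Stand-in for self.transformer(hidden_states=x_t, timestep=time_scale * t, ...);
    # a real flow model predicts the velocity of the path x_t = (1 - t) * x_0 + t * noise.
    return x_t

num_inference_steps = 4
latents = torch.randn(1, 4, 8, 8)      # initial Gaussian latents
D_x = torch.zeros_like(latents)        # running estimate of the clean sample
initial_latents = latents.clone()

for i in range(num_inference_steps):
    t = 1.0 - i / num_inference_steps  # 1.0, 0.75, 0.5, 0.25 (the 999/999 factors cancel)
    noise = latents if i == 0 else torch.randn_like(latents)    # noise_type="fresh"
    latents = (1.0 - t) * D_x + t * noise                       # re-noise around D_x
    flow_pred = toy_flow(latents, torch.full((latents.shape[0],), t))
    D_x = latents - t * flow_pred                               # clean estimate from the flow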
 
sid/pipeline_sid_sd3.py CHANGED
@@ -54,6 +54,7 @@ else:
 
 logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
+
 # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
 def calculate_shift(
     image_seq_len,
@@ -683,7 +684,8 @@ class SiDSD3Pipeline(
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 256,
         use_sd3_shift: bool = False,
-        noise_type: str = 'fresh', # 'fresh', 'ddim', 'fixed'
+        noise_type: str = "fresh", # 'fresh', 'ddim', 'fixed'
+        time_scale: float = 1000.0,
     ):
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor
@@ -749,25 +751,33 @@ class SiDSD3Pipeline(
         # Use fixed noise for now (can be extended as needed)
         initial_latents = latents.clone()
         for i in range(num_inference_steps):
-            if noise_type == 'fresh':
-                noise = latents if i == 0 else torch.randn_like(latents).to(latents.device)
-            elif noise_type=='ddim':
-                noise = latents if i == 0 else ((latents - (1.0 - t) * D_x) / t).detach()
-            elif noise_type == 'fixed':
+            if noise_type == "fresh":
+                noise = (
+                    latents if i == 0 else torch.randn_like(latents).to(latents.device)
+                )
+            elif noise_type == "ddim":
+                noise = (
+                    latents if i == 0 else ((latents - (1.0 - t) * D_x) / t).detach()
+                )
+            elif noise_type == "fixed":
                 noise = initial_latents # Use the initial, unmodified latents
             else:
                 raise ValueError(f"Unknown noise_type: {noise_type}")
-
+
             # Compute t value, normalized to [0, 1]
             init_timesteps = 999
-            scalar_t = float(init_timesteps) * (1.0 - float(i) / float(num_inference_steps))
+            scalar_t = float(init_timesteps) * (
+                1.0 - float(i) / float(num_inference_steps)
+            )
             t_val = scalar_t / 999.0
             # t_val = 1.0 - float(i) / float(num_inference_steps)
             if use_sd3_shift:
                 shift = 3.0
                 t_val = shift * t_val / (1 + (shift - 1) * t_val)
-
-            t = torch.full((latents.shape[0],), t_val, device=latents.device, dtype=latents.dtype)
+
+            t = torch.full(
+                (latents.shape[0],), t_val, device=latents.device, dtype=latents.dtype
+            )
             t_flattern = t.flatten()
             if t.numel() > 1:
                 t = t.view(-1, 1, 1, 1)
@@ -778,19 +788,28 @@ class SiDSD3Pipeline(
             flow_pred = self.transformer(
                 hidden_states=latent_model_input,
                 encoder_hidden_states=prompt_embeds,
-                #encoder_attention_mask=prompt_attention_mask,
+                # encoder_attention_mask=prompt_attention_mask,
                 pooled_projections=pooled_prompt_embeds,
-                timestep=1000*t_flattern,
+                timestep=time_scale * t_flattern,
                 return_dict=False,
             )[0]
-            D_x = latents - (t * flow_pred if torch.numel(t) == 1 else t.view(-1, 1, 1, 1) * flow_pred)
-
+            D_x = latents - (
+                t * flow_pred
+                if torch.numel(t) == 1
+                else t.view(-1, 1, 1, 1) * flow_pred
+            )
+
         # 5. Decode latent to image
-        image = self.vae.decode((D_x / self.vae.config.scaling_factor) + self.vae.config.shift_factor, return_dict=False)[0]
+        image = self.vae.decode(
+            (D_x / self.vae.config.scaling_factor) + self.vae.config.shift_factor,
+            return_dict=False,
+        )[0]
         image = self.image_processor.postprocess(image, output_type=output_type)
+
+        self.maybe_free_model_hooks()
 
         # 6. Return output
         if not return_dict:
             return (image,)
-
-        return SiDPipelineOutput(images=image)
+
+        return SiDPipelineOutput(images=image)
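
For reference, the unchanged use_sd3_shift branch above warps the uniform schedule t = 1 - i/N toward higher noise levels with shift = 3.0; the values below are computed directly from that formula (not taken from the repository):

# Effect of the SD3-style timestep shift applied when use_sd3_shift=True.
shift = 3.0
for i in range(4):
    t_val = 1.0 - i / 4                                   # uniform schedule: 1.0, 0.75, 0.5, 0.25
    shifted = shift * t_val / (1 + (shift - 1) * t_val)   # same formula as in the loop above
    print(f"t={t_val:.2f} -> shifted t={shifted:.3f}")
# prints: 1.00 -> 1.000, 0.75 -> 0.900, 0.50 -> 0.750, 0.25 -> 0.500

Note that, unlike the Sana decode, the SD3 decode adds self.vae.config.shift_factor after dividing by the scaling factor, which matches the SD3 VAE's latent parameterization.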