Yinhong Liu committed · Commit 3dfb2f9 · 1 parent: e5487ed

sana pipeline

Files changed:
- app.py +9 -9
- sid/pipeline_sid_sana.py +83 -242
- sid/pipeline_sid_sd3.py +36 -17
app.py
CHANGED
@@ -9,10 +9,6 @@ import torch
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-if torch.cuda.is_available():
-    torch_dtype = torch.float16
-else:
-    torch_dtype = torch.float32
 
 MODEL_OPTIONS = {
     "SiD-Flow-SD3-medium": "YGu1998/SiD-Flow-SD3-medium",
@@ -33,16 +29,19 @@ MODEL_OPTIONS = {
 
 def load_model(model_choice):
     model_repo_id = MODEL_OPTIONS[model_choice]
+    time_scale = 1000.0
     if "Sana" in model_choice:
-        pipe = SiDSanaPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
+        pipe = SiDSanaPipeline.from_pretrained(model_repo_id, torch_dtype=torch.float16)
+        if "Sprint" in model_choice:
+            time_scale = 1.0
     elif "SD3" in model_choice:
-        pipe = SiDSD3Pipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
+        pipe = SiDSD3Pipeline.from_pretrained(model_repo_id, torch_dtype=torch.float16)
     elif "Flux" in model_choice:
-        pipe = SiDFluxPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
+        pipe = SiDFluxPipeline.from_pretrained(model_repo_id, torch_dtype=torch.float16)
     else:
         raise ValueError(f"Unknown model type for: {model_choice}")
     pipe = pipe.to(device)
-    return pipe
+    return pipe, time_scale
 
 
 MAX_SEED = np.iinfo(np.int32).max
@@ -65,7 +64,7 @@ def infer(
 
     generator = torch.Generator().manual_seed(seed)
 
-    pipe = load_model(model_choice)
+    pipe, time_scale = load_model(model_choice)
 
     image = pipe(
         prompt=prompt,
@@ -74,6 +73,7 @@ def infer(
         width=width,
         height=height,
         generator=generator,
+        time_scale=time_scale,
     ).images[0]
 
     return image, seed
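A minimal sketch of how the new load_model contract is consumed (it assumes app.py's imports and MODEL_OPTIONS; the prompt, seed, and step count below are illustrative values, not ones taken from the Space):

# load_model now returns (pipe, time_scale): 1000.0 by default, 1.0 for "Sprint" checkpoints.
# The pipelines multiply their normalized t in [0, 1] by time_scale before conditioning the transformer.
pipe, time_scale = load_model("SiD-Flow-SD3-medium")
generator = torch.Generator().manual_seed(42)
image = pipe(
    prompt="a corgi wearing sunglasses on a beach",  # illustrative prompt
    num_inference_steps=4,                           # SiD pipelines are few-step samplers
    guidance_scale=1.0,
    width=1024,
    height=1024,
    generator=generator,
    time_scale=time_scale,
).images[0]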
sid/pipeline_sid_sana.py
CHANGED
@@ -700,141 +700,27 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         return self._interrupt
 
     @torch.no_grad()
-    @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        negative_prompt: Union[str, List[str]] = "",
-        num_inference_steps: int = 20,
-        timesteps: List[int] = None,
-        sigmas: List[float] = None,
-        guidance_scale: float = 4.5,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 28,
+        guidance_scale: float = 1.0,
         num_images_per_prompt: Optional[int] = 1,
-        height: int = 1024,
-        width: int = 1024,
-        eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.Tensor] = None,
-        prompt_embeds: Optional[torch.Tensor] = None,
-        prompt_attention_mask: Optional[torch.Tensor] = None,
-        negative_prompt_embeds: Optional[torch.Tensor] = None,
-        negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        clean_caption: bool = False,
-        use_resolution_binning: bool = True,
-        attention_kwargs: Optional[Dict[str, Any]] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        max_sequence_length: int = 300,
-        complex_human_instruction: List[str] = [
-            ...
-            "Here are examples of how to transform or refine prompts:",
-            "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
-            "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
-            "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
-            "User Prompt: ",
-        ],
-    ) -> Union[SiDPipelineOutput, Tuple]:
-        """
-        Function invoked when calling the pipeline for generation.
-
-        Args:
-            prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
-                instead.
-            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
-                less than `1`).
-            num_inference_steps (`int`, *optional*, defaults to 20):
-                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
-                expense of slower inference.
-            timesteps (`List[int]`, *optional*):
-                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
-                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
-                passed will be used. Must be in descending order.
-            sigmas (`List[float]`, *optional*):
-                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
-                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
-                will be used.
-            guidance_scale (`float`, *optional*, defaults to 4.5):
-                Guidance scale as defined in [Classifier-Free Diffusion
-                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
-                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
-                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
-                the text `prompt`, usually at the expense of lower image quality.
-            num_images_per_prompt (`int`, *optional*, defaults to 1):
-                The number of images to generate per prompt.
-            height (`int`, *optional*, defaults to self.unet.config.sample_size):
-                The height in pixels of the generated image.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size):
-                The width in pixels of the generated image.
-            eta (`float`, *optional*, defaults to 0.0):
-                Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
-                applies to [`schedulers.DDIMScheduler`], will be ignored for others.
-            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
-                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
-                to make generation deterministic.
-            latents (`torch.Tensor`, *optional*):
-                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
-                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
-                tensor will ge generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-            prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
-            negative_prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
-                provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
-            negative_prompt_attention_mask (`torch.Tensor`, *optional*):
-                Pre-generated attention mask for negative text embeddings.
-            output_type (`str`, *optional*, defaults to `"pil"`):
-                The output format of the generate image. Choose between
-                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
-            attention_kwargs:
-                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
-                `self.processor` in
-                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
-            clean_caption (`bool`, *optional*, defaults to `True`):
-                Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
-                be installed. If the dependencies are not installed, the embeddings will be created from the raw
-                prompt.
-            use_resolution_binning (`bool` defaults to `True`):
-                If set to `True`, the requested height and width are first mapped to the closest resolutions using
-                `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
-                the requested resolution. Useful for generating non-square images.
-            callback_on_step_end (`Callable`, *optional*):
-                A function that calls at the end of each denoising steps during the inference. The function is called
-                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
-                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
-                `callback_on_step_end_tensor_inputs`.
-            callback_on_step_end_tensor_inputs (`List`, *optional*):
-                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
-                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
-                `._callback_tensor_inputs` attribute of your pipeline class.
-            max_sequence_length (`int` defaults to `300`):
-                Maximum sequence length to use with the `prompt`.
-            complex_human_instruction (`List[str]`, *optional*):
-                Instructions for complex human attention:
-                https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
-
-        Examples:
-
-        Returns:
-            [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned,
-                otherwise a `tuple` is returned where the first element is a list with the generated images
-        """
-
-        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
-            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
-
-        # 1. Check inputs. Raise error if not correct
+        max_sequence_length: int = 256,
+        noise_type: str = "fresh",  # 'fresh', 'ddim', 'fixed'
+        time_scale: float = 1000.0,
+        use_resolution_binning: bool = True,
+    ):
         if use_resolution_binning:
             if self.transformer.config.sample_size == 128:
                 aspect_ratio_bin = ASPECT_RATIO_4096_BIN
@@ -848,24 +734,24 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
             raise ValueError("Invalid sample size")
         orig_height, orig_width = height, width
         height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
+        # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
             height,
             width,
-            callback_on_step_end_tensor_inputs,
-            negative_prompt,
-            prompt_embeds,
-            negative_prompt_embeds,
-            prompt_attention_mask,
-            negative_prompt_attention_mask,
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
         )
 
         self._guidance_scale = guidance_scale
-        self._attention_kwargs = attention_kwargs
         self._interrupt = False
 
-        # 2.
+        # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
@@ -874,134 +760,89 @@ class SiDSanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
             batch_size = prompt_embeds.shape[0]
 
         device = self._execution_device
-        lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
 
-        # 3. Encode input prompt
         (
             prompt_embeds,
-            prompt_attention_mask,
-            negative_prompt_embeds,
-            negative_prompt_attention_mask,
+            pooled_prompt_embeds,
+            _, _,
         ) = self.encode_prompt(
             prompt,
-            self.do_classifier_free_guidance,
-            negative_prompt=negative_prompt,
-            num_images_per_prompt=num_images_per_prompt,
-            device=device,
             prompt_embeds=prompt_embeds,
-            prompt_attention_mask=prompt_attention_mask,
-            negative_prompt_embeds=negative_prompt_embeds,
-            negative_prompt_attention_mask=negative_prompt_attention_mask,
-            clean_caption=clean_caption,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
-            complex_human_instruction=complex_human_instruction,
-            lora_scale=lora_scale,
         )
-
-            ...
-            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
-
-        # 4. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler, num_inference_steps, device, timesteps, sigmas
-        )
-
-        # 5. Prepare latents.
-        latent_channels = self.transformer.config.in_channels
+        # 3. Prepare latents
+        num_channels_latents = self.transformer.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
-            latent_channels,
+            num_channels_latents,
             height,
             width,
-            torch.float32,
+            prompt_embeds.dtype,
             device,
             generator,
            latents,
        )
 
-        # 6. Prepare extra step kwargs
-        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
-
-        # 7. Denoising loop
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-        self._num_timesteps = len(timesteps)
-
-        transformer_dtype = self.transformer.dtype
-        with self.progress_bar(total=num_inference_steps) as progress_bar:
-            for i, t in enumerate(timesteps):
-                if self.interrupt:
-                    continue
-
-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-
-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = t.expand(latent_model_input.shape[0])
-                timestep = timestep * self.transformer.config.timestep_scale
-
-                # predict noise model_output
-                noise_pred = self.transformer(
-                    latent_model_input.to(dtype=transformer_dtype),
-                    encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
-                    encoder_attention_mask=prompt_attention_mask,
-                    timestep=timestep,
-                    return_dict=False,
-                    attention_kwargs=self.attention_kwargs,
-                )[0]
-                noise_pred = noise_pred.float()
-
-                # perform guidance
-                if self.do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                # learned sigma
-                if self.transformer.config.out_channels // 2 == latent_channels:
-                    noise_pred = noise_pred.chunk(2, dim=1)[0]
-
-                # compute previous image: x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
-
-                if callback_on_step_end is not None:
-                    callback_kwargs = {}
-                    for k in callback_on_step_end_tensor_inputs:
-                        callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                    latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-
-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
-
-                if XLA_AVAILABLE:
-                    xm.mark_step()
-
-        if output_type == "latent":
-            image = latents
-        else:
-            latents = latents.to(self.vae.dtype)
-            torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
-            oom_error = (
-                torch.OutOfMemoryError
-                if is_torch_version(">=", "2.5.0")
-                else torch_accelerator_module.OutOfMemoryError
-            )
-            try:
-                image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
-            except oom_error as e:
-                warnings.warn(
-                    f"{e}. \n"
-                    f"Try to use VAE tiling for large images. For example: \n"
-                    f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
-                )
-            ...
+        # 4. SiD sampling loop
+        # Initialize D_x
+        D_x = torch.zeros_like(latents).to(latents.device)
+        # Use fixed noise for now (can be extended as needed)
+        initial_latents = latents.clone()
+        for i in range(num_inference_steps):
+            if noise_type == "fresh":
+                noise = (
+                    latents if i == 0 else torch.randn_like(latents).to(latents.device)
+                )
+            elif noise_type == "ddim":
+                noise = (
+                    latents if i == 0 else ((latents - (1.0 - t) * D_x) / t).detach()
+                )
+            elif noise_type == "fixed":
+                noise = initial_latents  # Use the initial, unmodified latents
+            else:
+                raise ValueError(f"Unknown noise_type: {noise_type}")
+
+            # Compute t value, normalized to [0, 1]
+            init_timesteps = 999
+            scalar_t = float(init_timesteps) * (
+                1.0 - float(i) / float(num_inference_steps)
+            )
+            t_val = scalar_t / 999.0
+            t = torch.full(
+                (latents.shape[0],), t_val, device=latents.device, dtype=latents.dtype
+            )
+            t_flattern = t.flatten()
+            if t.numel() > 1:
+                t = t.view(-1, 1, 1, 1)
+
+            latents = (1.0 - t) * D_x + t * noise
+            latent_model_input = latents
+
+            flow_pred = self.transformer(
+                hidden_states=latent_model_input,
+                encoder_hidden_states=prompt_embeds,
+                # encoder_attention_mask=prompt_attention_mask,
+                pooled_projections=pooled_prompt_embeds,
+                timestep=time_scale * t_flattern,
+                return_dict=False,
+            )[0]
+            D_x = latents - (
+                t * flow_pred
+                if torch.numel(t) == 1
+                else t.view(-1, 1, 1, 1) * flow_pred
+            )
 
+        # 5. Decode latent to image
+        image = self.vae.decode(
+            (D_x / self.vae.config.scaling_factor),
+            return_dict=False,
+        )[0]
+        if use_resolution_binning:
+            image = self.image_processor.resize_and_crop_tensor(image, orig_height, orig_width)
+        image = self.image_processor.postprocess(image, output_type=output_type)
         # Offload all models
         self.maybe_free_model_hooks()
 
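The replacement loop above is a few-step SiD flow sampler rather than a scheduler-driven denoising loop: it keeps a running estimate D_x of the clean latent, re-noises it to x_t = (1 - t) * D_x + t * noise at a decreasing sequence of t values, asks the transformer for a flow prediction, and updates D_x = x_t - t * flow_pred. A minimal, self-contained sketch of that update rule (toy tensors and a stand-in velocity function; the names sid_sample and velocity_fn are illustrative, not part of the pipeline):

import torch

def sid_sample(velocity_fn, shape, num_steps=4, noise_type="fresh", generator=None):
    # Toy re-implementation of the loop above: D_x is the running estimate of the clean
    # sample, t walks down a uniform grid from 1.0, and each step re-noises D_x before
    # asking the velocity model for a fresh prediction. The real pipeline also supports a
    # "ddim" mode that re-derives the noise from the current estimate instead.
    noise0 = torch.randn(shape, generator=generator)
    D_x = torch.zeros(shape)
    for i in range(num_steps):
        t = 1.0 - i / num_steps  # the pipeline computes this as 999 * (1 - i/N) / 999
        if noise_type == "fresh":
            noise = noise0 if i == 0 else torch.randn(shape, generator=generator)
        elif noise_type == "fixed":
            noise = noise0
        else:
            raise ValueError(f"Unknown noise_type: {noise_type}")
        x_t = (1.0 - t) * D_x + t * noise   # re-noise the current estimate
        v = velocity_fn(x_t, t)             # stands in for the transformer's flow prediction
        D_x = x_t - t * v                   # refreshed estimate of the clean sample
    return D_x

# Sanity check with the exact velocity of a linear flow toward a fixed target: the first
# step (t = 1) already recovers the target, which is the point of the distilled sampler.
target = torch.randn(2, 3)
recovered = sid_sample(lambda x, t: (x - target) / max(t, 1e-8), (2, 3), num_steps=1)
assert torch.allclose(recovered, target, atol=1e-5)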
sid/pipeline_sid_sd3.py
CHANGED
@@ -54,6 +54,7 @@ else:
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
 def calculate_shift(
     image_seq_len,
@@ -683,7 +684,8 @@ class SiDSD3Pipeline(
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 256,
         use_sd3_shift: bool = False,
-        noise_type: str = ...,
+        noise_type: str = "fresh",  # 'fresh', 'ddim', 'fixed'
+        time_scale: float = 1000.0,
     ):
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor
@@ -749,25 +751,33 @@ class SiDSD3Pipeline(
         # Use fixed noise for now (can be extended as needed)
         initial_latents = latents.clone()
         for i in range(num_inference_steps):
-            if noise_type == "fresh":
-                noise = latents if i == 0 else torch.randn_like(latents).to(latents.device)
-            elif noise_type == "ddim":
-                noise = latents if i == 0 else ((latents - (1.0 - t) * D_x) / t).detach()
-            elif noise_type == "fixed":
+            if noise_type == "fresh":
+                noise = (
+                    latents if i == 0 else torch.randn_like(latents).to(latents.device)
+                )
+            elif noise_type == "ddim":
+                noise = (
+                    latents if i == 0 else ((latents - (1.0 - t) * D_x) / t).detach()
+                )
+            elif noise_type == "fixed":
                 noise = initial_latents  # Use the initial, unmodified latents
             else:
                 raise ValueError(f"Unknown noise_type: {noise_type}")
 
             # Compute t value, normalized to [0, 1]
             init_timesteps = 999
-            scalar_t = float(init_timesteps) * (1.0 - float(i) / float(num_inference_steps))
+            scalar_t = float(init_timesteps) * (
+                1.0 - float(i) / float(num_inference_steps)
+            )
             t_val = scalar_t / 999.0
             # t_val = 1.0 - float(i) / float(num_inference_steps)
             if use_sd3_shift:
                 shift = 3.0
                 t_val = shift * t_val / (1 + (shift - 1) * t_val)
 
-            t = torch.full((latents.shape[0],), t_val, device=latents.device, dtype=latents.dtype)
+            t = torch.full(
+                (latents.shape[0],), t_val, device=latents.device, dtype=latents.dtype
+            )
             t_flattern = t.flatten()
             if t.numel() > 1:
                 t = t.view(-1, 1, 1, 1)
@@ -778,19 +788,28 @@ class SiDSD3Pipeline(
             flow_pred = self.transformer(
                 hidden_states=latent_model_input,
                 encoder_hidden_states=prompt_embeds,
-                #encoder_attention_mask=prompt_attention_mask,
+                # encoder_attention_mask=prompt_attention_mask,
                 pooled_projections=pooled_prompt_embeds,
-                timestep=...,
+                timestep=time_scale * t_flattern,
                 return_dict=False,
             )[0]
-            D_x = latents - (
-                ...
+            D_x = latents - (
+                t * flow_pred
+                if torch.numel(t) == 1
+                else t.view(-1, 1, 1, 1) * flow_pred
+            )
+
         # 5. Decode latent to image
-        image = self.vae.decode(...)
+        image = self.vae.decode(
+            (D_x / self.vae.config.scaling_factor) + self.vae.config.shift_factor,
+            return_dict=False,
+        )[0]
         image = self.image_processor.postprocess(image, output_type=output_type)
+
+        self.maybe_free_model_hooks()
 
         # 6. Return output
         if not return_dict:
             return (image,)
-
-        return SiDPipelineOutput(images=image)
+
+        return SiDPipelineOutput(images=image)