DreamO / dreamo /dreamo_pipeline.py
yanze's picture
init
92c3d7f
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Optional, Union
import diffusers
import numpy as np
import torch
import torch.nn as nn
from diffusers import FluxPipeline
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from einops import repeat
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from dreamo.transformer import flux_transformer_forward
from dreamo.utils import convert_flux_lora_to_diffusers
diffusers.models.transformers.transformer_flux.FluxTransformer2DModel.forward = flux_transformer_forward
def get_task_embedding_idx(task):
return 0
class DreamOPipeline(FluxPipeline):
def __init__(self, scheduler, vae, text_encoder, tokenizer, text_encoder_2, tokenizer_2, transformer):
super().__init__(scheduler, vae, text_encoder, tokenizer, text_encoder_2, tokenizer_2, transformer)
self.t5_embedding = nn.Embedding(10, 4096)
self.task_embedding = nn.Embedding(2, 3072)
self.idx_embedding = nn.Embedding(10, 3072)
def load_dreamo_model(self, device, use_turbo=True):
hf_hub_download(repo_id='ByteDance/DreamO', filename='dreamo.safetensors', local_dir='models')
hf_hub_download(repo_id='ByteDance/DreamO', filename='dreamo_cfg_distill.safetensors', local_dir='models')
dreamo_lora = load_file('models/dreamo.safetensors')
cfg_distill_lora = load_file('models/dreamo_cfg_distill.safetensors')
self.t5_embedding.weight.data = dreamo_lora.pop('dreamo_t5_embedding.weight')[-10:]
self.task_embedding.weight.data = dreamo_lora.pop('dreamo_task_embedding.weight')
self.idx_embedding.weight.data = dreamo_lora.pop('dreamo_idx_embedding.weight')
self._prepare_t5()
dreamo_diffuser_lora = convert_flux_lora_to_diffusers(dreamo_lora)
cfg_diffuser_lora = convert_flux_lora_to_diffusers(cfg_distill_lora)
adapter_names = ['dreamo']
adapter_weights = [1]
self.load_lora_weights(dreamo_diffuser_lora, adapter_name='dreamo')
if cfg_diffuser_lora is not None:
self.load_lora_weights(cfg_diffuser_lora, adapter_name='cfg')
adapter_names.append('cfg')
adapter_weights.append(1)
if use_turbo:
self.load_lora_weights(
hf_hub_download(
"alimama-creative/FLUX.1-Turbo-Alpha", "diffusion_pytorch_model.safetensors", local_dir='models'
),
adapter_name='turbo',
)
adapter_names.append('turbo')
adapter_weights.append(1)
self.fuse_lora(adapter_names=adapter_names, adapter_weights=adapter_weights, lora_scale=1)
self.t5_embedding = self.t5_embedding.to(device)
self.task_embedding = self.task_embedding.to(device)
self.idx_embedding = self.idx_embedding.to(device)
def _prepare_t5(self):
self.text_encoder_2.resize_token_embeddings(len(self.tokenizer_2))
num_new_token = 10
new_token_list = [f"[ref#{i}]" for i in range(1, 10)] + ["[res]"]
self.tokenizer_2.add_tokens(new_token_list, special_tokens=False)
self.text_encoder_2.resize_token_embeddings(len(self.tokenizer_2))
input_embedding = self.text_encoder_2.get_input_embeddings().weight.data
input_embedding[-num_new_token:] = self.t5_embedding.weight.data
@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype, start_height=0, start_width=0):
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + start_height
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + start_width
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod
def _prepare_style_latent_image_ids(batch_size, height, width, device, dtype, start_height=0, start_width=0):
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + start_height
latent_image_ids[..., 2] = latent_image_ids[..., 2] + start_width
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1)
latent_image_ids = latent_image_ids.reshape(
batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt: Union[str, List[str]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
true_cfg_scale: float = 1.0,
true_cfg_start_step: int = 1,
true_cfg_end_step: int = 1,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 3.5,
neg_guidance_scale: float = 3.5,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
ref_conds=None,
first_step_guidance_scale=3.5,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
will be used instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
not greater than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
true_cfg_scale (`float`, *optional*, defaults to 1.0):
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
height,
width,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._current_timestep = None
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
)
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
(
prompt_embeds,
pooled_prompt_embeds,
text_ids,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
if do_true_cfg:
(
negative_prompt_embeds,
negative_pooled_prompt_embeds,
_,
) = self.encode_prompt(
prompt=negative_prompt,
prompt_2=negative_prompt_2,
prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=negative_pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, latent_image_ids = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 4.1 concat ref tokens to latent
origin_img_len = latents.shape[1]
embeddings = repeat(self.task_embedding.weight[1], "c -> n l c", n=batch_size, l=origin_img_len)
ref_latents = []
ref_latent_image_idss = []
start_height = height // 16
start_width = width // 16
for ref_cond in ref_conds:
img = ref_cond['img'] # [b, 3, h, w], range [-1, 1]
task = ref_cond['task']
idx = ref_cond['idx']
# encode ref with VAE
img = img.to(latents)
ref_latent = self.vae.encode(img).latent_dist.sample()
ref_latent = (ref_latent - self.vae.config.shift_factor) * self.vae.config.scaling_factor
cur_height = ref_latent.shape[2]
cur_width = ref_latent.shape[3]
ref_latent = self._pack_latents(ref_latent, batch_size, num_channels_latents, cur_height, cur_width)
ref_latent_image_ids = self._prepare_latent_image_ids(
batch_size, cur_height, cur_width, device, prompt_embeds.dtype, start_height, start_width
)
start_height += cur_height // 2
start_width += cur_width // 2
# prepare task_idx_embedding
task_idx = get_task_embedding_idx(task)
cur_task_embedding = repeat(
self.task_embedding.weight[task_idx], "c -> n l c", n=batch_size, l=ref_latent.shape[1]
)
cur_idx_embedding = repeat(
self.idx_embedding.weight[idx], "c -> n l c", n=batch_size, l=ref_latent.shape[1]
)
cur_embedding = cur_task_embedding + cur_idx_embedding
# concat ref to latent
embeddings = torch.cat([embeddings, cur_embedding], dim=1)
ref_latents.append(ref_latent)
ref_latent_image_idss.append(ref_latent_image_ids)
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# handle guidance
if self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
neg_guidance = torch.full([1], neg_guidance_scale, device=device, dtype=torch.float32)
neg_guidance = neg_guidance.expand(latents.shape[0])
first_step_guidance = torch.full([1], first_step_guidance_scale, device=device, dtype=torch.float32)
if self.joint_attention_kwargs is None:
self._joint_attention_kwargs = {}
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
hidden_states=torch.cat((latents, *ref_latents), dim=1),
timestep=timestep / 1000,
guidance=guidance if i > 0 else first_step_guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=torch.cat((latent_image_ids, *ref_latent_image_idss), dim=1),
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
embeddings=embeddings,
)[0][:, :origin_img_len]
if do_true_cfg and i >= true_cfg_start_step and i < true_cfg_end_step:
neg_noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=neg_guidance,
pooled_projections=negative_pooled_prompt_embeds,
encoder_hidden_states=negative_prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype and torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
self._current_timestep = None
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)