Spaces:
Running
on
Zero
Running
on
Zero
| from diffusers.pipelines import FluxPipeline | |
| from diffusers.utils import logging | |
| from diffusers.pipelines.flux.pipeline_flux import logger | |
| from torch import Tensor | |
def encode_images(pipeline: FluxPipeline, images: Tensor):
    """Encode images with the pipeline's VAE and pack them into tokens.

    The images are preprocessed by ``pipeline.image_processor``, moved to the
    pipeline's device and dtype, encoded by ``pipeline.vae`` (one sample is
    drawn from the returned latent distribution), and then normalised with the
    VAE config's ``shift_factor`` and ``scaling_factor``.

    Args:
        pipeline: Pipeline supplying the image processor, VAE, device, dtype
            and the ``_pack_latents`` / ``_prepare_latent_image_ids`` helpers.
        images: Input images, passed as-is to the image processor.

    Returns:
        A ``(images_tokens, images_ids)`` tuple: the packed latents and the
        matching latent image ids.
    """
    vae = pipeline.vae
    device = pipeline.device
    dtype = pipeline.dtype

    pixels = pipeline.image_processor.preprocess(images)
    pixels = pixels.to(device).to(dtype)

    # Sample latents, then shift and scale them using the VAE configuration.
    latents = vae.encode(pixels).latent_dist.sample()
    latents = (latents - vae.config.shift_factor) * vae.config.scaling_factor

    # The latent shape is forwarded positionally to the packing helper.
    images_tokens = pipeline._pack_latents(latents, *latents.shape)

    # NOTE(review): assumes latents are laid out as
    # (batch, channels, height, width) - confirm against the VAE in use.
    batch = latents.shape[0]
    height = latents.shape[2]
    width = latents.shape[3]

    def build_ids(id_height, id_width):
        return pipeline._prepare_latent_image_ids(
            batch, id_height, id_width, device, dtype
        )

    images_ids = build_ids(height, width)

    # When the number of ids does not match the number of tokens, rebuild the
    # ids from halved spatial dimensions. Presumably this covers helper
    # versions that expect already-halved sizes - verify against the installed
    # diffusers release.
    if images_tokens.shape[1] != images_ids.shape[0]:
        images_ids = build_ids(height // 2, width // 2)

    return images_tokens, images_ids
def prepare_text_input(pipeline: FluxPipeline, prompts, max_sequence_length=512):
    """Encode text prompts with the pipeline's ``encode_prompt`` method.

    While the prompts are being encoded, the flux pipeline logger is raised to
    ``ERROR`` so that warnings emitted during encoding are suppressed. The
    logger's previous level is restored afterwards, including when encoding
    raises an exception.

    Args:
        pipeline: Pipeline providing ``encode_prompt`` and ``device``.
        prompts: Prompt(s) forwarded as the ``prompt`` argument.
        max_sequence_length: Forwarded unchanged to ``encode_prompt``.

    Returns:
        A ``(prompt_embeds, pooled_prompt_embeds, text_ids)`` tuple, exactly as
        returned by ``pipeline.encode_prompt``.
    """
    # Remember the configured level so it can be put back exactly as it was,
    # instead of unconditionally forcing WARNING afterwards.
    previous_level = logger.level

    # Turn off warnings (CLIP overflow)
    logger.setLevel(logging.ERROR)
    try:
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = pipeline.encode_prompt(
            prompt=prompts,
            prompt_2=None,
            prompt_embeds=None,
            pooled_prompt_embeds=None,
            device=pipeline.device,
            num_images_per_prompt=1,
            max_sequence_length=max_sequence_length,
            lora_scale=None,
        )
    finally:
        # Always undo the temporary silencing, even if encoding failed.
        logger.setLevel(previous_level)

    return prompt_embeds, pooled_prompt_embeds, text_ids