PRX

PRX generates high-quality images from text using a simplified MMDiT architecture in which text tokens are not updated through the transformer blocks. It uses flow matching with discrete scheduling for efficient sampling, and Google's T5Gemma-2B-2B-UL2 model for multilingual text encoding. The ~1.3B-parameter transformer delivers fast inference without sacrificing quality. You can choose between the Flux VAE (8x compression, 16 latent channels) for balanced quality and speed, or DC-AE (32x compression, 32 latent channels) for higher compression and faster latent processing.
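
The VAE choice determines the size of the latent grid the transformer operates on. A quick sketch of the arithmetic (illustrative only; the exact latent layout is internal to each VAE):

# Latent grid implied by each VAE at 512x512, as (channels, height, width)
height = width = 512

flux_latents = (16, height // 8, width // 8)     # Flux VAE: 8x compression, 16 channels
dc_ae_latents = (32, height // 32, width // 32)  # DC-AE: 32x compression, 32 channels

print(flux_latents)   # (16, 64, 64)
print(dc_ae_latents)  # (32, 16, 16)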

Available models

PRX offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the Alchemist dataset improve aesthetic quality, especially with simpler prompts.

| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
|---|---|---|---|---|---|---|---|
| Photoroom/prx-256-t2i | 256 | No | No | Base model pre-trained at 256 with Flux VAE | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | torch.bfloat16 |
| Photoroom/prx-256-t2i-sft | 256 | Yes | No | Fine-tuned on the Alchemist dataset with Flux VAE | Can handle less detailed prompts | 28 steps, cfg=5.0 | torch.bfloat16 |
| Photoroom/prx-512-t2i | 512 | No | No | Base model pre-trained at 512 with Flux VAE | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | torch.bfloat16 |
| Photoroom/prx-512-t2i-sft | 512 | Yes | No | Fine-tuned on the Alchemist dataset with Flux VAE | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | torch.bfloat16 |
| Photoroom/prx-512-t2i-sft-distilled | 512 | Yes | Yes | 8-step distilled model from Photoroom/prx-512-t2i-sft | Can handle less detailed prompts in natural language | 8 steps, cfg=1.0 | torch.bfloat16 |
| Photoroom/prx-512-t2i-dc-ae | 512 | No | No | Base model pre-trained at 512 with Deep Compression Autoencoder (DC-AE) | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | torch.bfloat16 |
| Photoroom/prx-512-t2i-dc-ae-sft | 512 | Yes | No | Fine-tuned on the Alchemist dataset with Deep Compression Autoencoder (DC-AE) | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | torch.bfloat16 |
| Photoroom/prx-512-t2i-dc-ae-sft-distilled | 512 | Yes | Yes | 8-step distilled model from Photoroom/prx-512-t2i-dc-ae-sft | Can handle less detailed prompts in natural language | 8 steps, cfg=1.0 | torch.bfloat16 |

Refer to the PRX model collection on the Hugging Face Hub for more information.

Loading the pipeline

Load the pipeline with from_pretrained().

import torch
from diffusers.pipelines.prx import PRXPipeline

# Load pipeline - the VAE and text encoder are downloaded from the Hub
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A front-facing portrait of a lion in the golden savanna at sunset."
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
image.save("prx_output.png")
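
The distilled checkpoints use the same API but are tuned for 8 steps with guidance disabled, per the table above:

# Distilled variant: fewer steps, cfg=1.0 (guidance effectively off)
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft-distilled", torch_dtype=torch.bfloat16)
pipe.to("cuda")
image = pipe(prompt, num_inference_steps=8, guidance_scale=1.0).images[0]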

Manual Component Loading

Load components individually to customize the pipeline, for instance to use quantized models.

import torch
from diffusers.pipelines.prx import PRXPipeline
from diffusers.models import AutoencoderKL, AutoencoderDC
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import T5GemmaModel, GemmaTokenizerFast
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

# Separate quantization configs: diffusers models take the diffusers config,
# transformers models take the transformers config
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
text_quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True)
# Load transformer
transformer = PRXTransformer2DModel.from_pretrained(
    "Photoroom/prx-512-t2i-sft",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

# Load scheduler
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "Photoroom/prx-512-t2i-sft", subfolder="scheduler"
)

# Load T5Gemma text encoder (use the transformers quantization config here)
t5gemma_model = T5GemmaModel.from_pretrained(
    "google/t5gemma-2b-2b-ul2",
    quantization_config=text_quant_config,
    torch_dtype=torch.bfloat16,
)
# Note: `.to(dtype=...)` is not supported on 8-bit bitsandbytes models,
# so keep the encoder as loaded
text_encoder = t5gemma_model.encoder
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
tokenizer.model_max_length = 256

# Load VAE - choose either Flux VAE or DC-AE
# Flux VAE
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev",
                                    subfolder="vae",
                                    quantization_config=quant_config,
                                    torch_dtype=torch.bfloat16)
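# Or, for the -dc-ae checkpoints, swap in DC-AE instead. A sketch, assuming the
# checkpoint stores its VAE in a "vae" subfolder like the Flux-VAE variants:
# vae = AutoencoderDC.from_pretrained("Photoroom/prx-512-t2i-dc-ae",
#                                     subfolder="vae",
#                                     torch_dtype=torch.bfloat16)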

pipe = PRXPipeline(
    transformer=transformer,
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    vae=vae
)
pipe.to("cuda")

Memory Optimization

For memory-constrained environments:

import torch
from diffusers.pipelines.prx import PRXPipeline

pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # Offload components to CPU when not in use

# Or use sequential CPU offload for even lower memory (slower)
# pipe.enable_sequential_cpu_offload()
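
Generation then works as usual; offloaded components are moved to the GPU on demand. Do not call pipe.to("cuda") after enabling offloading:

prompt = "A watercolor fox in a snowy forest"
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]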

PRXPipeline

class diffusers.PRXPipeline

( transformer: PRXTransformer2DModel scheduler: FlowMatchEulerDiscreteScheduler text_encoder: T5GemmaEncoder tokenizer: typing.Union[transformers.models.t5.tokenization_t5_fast.T5TokenizerFast, transformers.models.gemma.tokenization_gemma_fast.GemmaTokenizerFast, transformers.models.auto.tokenization_auto.AutoTokenizer] vae: typing.Union[diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL, diffusers.models.autoencoders.autoencoder_dc.AutoencoderDC, NoneType] = None default_sample_size: typing.Optional[int] = 512 )

Parameters

  • transformer (PRXTransformer2DModel) — The PRX transformer model to denoise the encoded image latents.
  • scheduler (FlowMatchEulerDiscreteScheduler) — A scheduler to be used in combination with transformer to denoise the encoded image latents.
  • text_encoder (T5GemmaEncoder) — Text encoder model for encoding prompts.
  • tokenizer (T5TokenizerFast or GemmaTokenizerFast) — Tokenizer for the text encoder.
  • vae (AutoencoderKL or AutoencoderDC) — Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. Supports both AutoencoderKL (8x compression) and AutoencoderDC (32x compression).

Pipeline for text-to-image generation using PRX Transformer.

This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

__call__

( prompt: typing.Union[str, typing.List[str]] = None negative_prompt: str = '' height: typing.Optional[int] = None width: typing.Optional[int] = None num_inference_steps: int = 28 timesteps: typing.List[int] = None guidance_scale: float = 4.0 num_images_per_prompt: typing.Optional[int] = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.FloatTensor] = None negative_prompt_embeds: typing.Optional[torch.FloatTensor] = None prompt_attention_mask: typing.Optional[torch.BoolTensor] = None negative_prompt_attention_mask: typing.Optional[torch.BoolTensor] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True use_resolution_binning: bool = True callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] ) PRXPipelineOutput or tuple

Parameters

  • prompt (str or List[str], optional) — The prompt or prompts to guide the image generation. If not defined, one has to pass prompt_embeds instead.
  • negative_prompt (str, optional, defaults to "") — The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored if guidance_scale is less than 1).
  • height (int, optional, defaults to self.transformer.config.sample_size * self.vae_scale_factor) — The height in pixels of the generated image.
  • width (int, optional, defaults to self.transformer.config.sample_size * self.vae_scale_factor) — The width in pixels of the generated image.
  • num_inference_steps (int, optional, defaults to 28) — The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference.
  • timesteps (List[int], optional) — Custom timesteps to use for the denoising process with schedulers which support a timesteps argument in their set_timesteps method. If not defined, the default behavior when num_inference_steps is passed will be used. Must be in descending order.
  • guidance_scale (float, optional, defaults to 4.0) — Guidance scale as defined in Classifier-Free Diffusion Guidance. guidance_scale is defined as w of equation 2. of the Imagen paper. Guidance scale is enabled by setting guidance_scale > 1. A higher guidance scale encourages the model to generate images that are closely linked to the text prompt, usually at the expense of lower image quality.
  • num_images_per_prompt (int, optional, defaults to 1) — The number of images to generate per prompt.
  • generator (torch.Generator or List[torch.Generator], optional) — One or a list of torch generator(s) to make generation deterministic.
  • latents (torch.Tensor, optional) — Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random generator.
  • prompt_embeds (torch.FloatTensor, optional) — Pre-generated text embeddings. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided, text embeddings will be generated from prompt input argument.
  • negative_prompt_embeds (torch.FloatTensor, optional) — Pre-generated negative text embeddings. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided and guidance_scale > 1, negative embeddings will be generated from an empty string.
  • prompt_attention_mask (torch.BoolTensor, optional) — Pre-generated attention mask for prompt_embeds. If not provided, attention mask will be generated from prompt input argument.
  • negative_prompt_attention_mask (torch.BoolTensor, optional) — Pre-generated attention mask for negative_prompt_embeds. If not provided and guidance_scale > 1, attention mask will be generated from an empty string.
  • output_type (str, optional, defaults to "pil") — The output format of the generated image. Choose between PIL (PIL.Image.Image) or np.array.
  • return_dict (bool, optional, defaults to True) — Whether or not to return a PRXPipelineOutput instead of a plain tuple.
  • use_resolution_binning (bool, optional, defaults to True) — If set to True, the requested height and width are first mapped to the closest resolutions using predefined aspect ratio bins. After the produced latents are decoded into images, they are resized back to the requested resolution. Useful for generating non-square images at optimal resolutions.
  • callback_on_step_end (Callable, optional) — A function that is called at the end of each denoising step during inference. The function is called with the following arguments: callback_on_step_end(self, step, timestep, callback_kwargs). callback_kwargs will include a list of all tensors as specified by callback_on_step_end_tensor_inputs.
  • callback_on_step_end_tensor_inputs (List[str], optional) — The list of tensor inputs for the callback_on_step_end function. The tensors specified in the list will be passed as the callback_kwargs argument. You will only be able to include tensors that are listed in the ._callback_tensor_inputs attribute.

Returns

PRXPipelineOutput or tuple

PRXPipelineOutput if return_dict is True, otherwise a tuple. When returning a tuple, the first element is a list with the generated images.

Function invoked when calling the pipeline for generation.

Examples:

>>> import torch
>>> from diffusers import PRXPipeline

>>> # Load pipeline with from_pretrained
>>> pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")

>>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
>>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
>>> image.save("prx_output.png")
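
For step-by-step inspection, a minimal callback sketch (assuming the default "latents" tensor input; diffusers callbacks return their callback_kwargs dict):

def on_step_end(pipeline, step, timestep, callback_kwargs):
    # Inspect (or modify) tensors between denoising steps
    latents = callback_kwargs["latents"]
    print(f"step {step}: latents shape {tuple(latents.shape)}")
    return callback_kwargs

image = pipe(prompt, callback_on_step_end=on_step_end).images[0]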

check_inputs

( prompt: typing.Union[str, typing.List[str]] height: int width: int guidance_scale: float callback_on_step_end_tensor_inputs: typing.Optional[typing.List[str]] = None prompt_embeds: typing.Optional[torch.FloatTensor] = None negative_prompt_embeds: typing.Optional[torch.FloatTensor] = None )

Check that all inputs are in the correct format.

encode_prompt

( prompt: typing.Union[str, typing.List[str]] device: typing.Optional[torch.device] = None do_classifier_free_guidance: bool = True negative_prompt: str = '' num_images_per_prompt: int = 1 prompt_embeds: typing.Optional[torch.FloatTensor] = None negative_prompt_embeds: typing.Optional[torch.FloatTensor] = None prompt_attention_mask: typing.Optional[torch.BoolTensor] = None negative_prompt_attention_mask: typing.Optional[torch.BoolTensor] = None )

Encode text prompt using standard text encoder and tokenizer, or use precomputed embeddings.

get_default_resolution

( ) int

Returns

int

The default sample size (height/width) to use for generation.

Determine the default resolution based on the loaded VAE and config.
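
A usage sketch, querying the checkpoint's native resolution instead of hard-coding it:

size = pipe.get_default_resolution()  # e.g. 512 for the 512 checkpoints
image = pipe(prompt, height=size, width=size).images[0]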

prepare_latents

( batch_size: int num_channels_latents: int height: int width: int dtype: dtype device: device generator: typing.Optional[torch._C.Generator] = None latents: typing.Optional[torch.Tensor] = None )

Prepare initial latents for the diffusion process.
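
prepare_latents is called internally. From user code, the usual way to control the initial noise is a seeded generator (or a pre-generated latents tensor) passed to the pipeline call, as documented above:

# Deterministic generation: the generator seeds the initial latent noise
generator = torch.Generator(device="cuda").manual_seed(42)
image = pipe(prompt, generator=generator).images[0]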

PRXPipelineOutput

class diffusers.pipelines.prx.PRXPipelineOutput

( images: typing.Union[typing.List[PIL.Image.Image], numpy.ndarray] )

Parameters

  • images (List[PIL.Image.Image] or np.ndarray) — List of denoised PIL images of length batch_size, or a numpy array of shape (batch_size, height, width, num_channels), containing the denoised images of the diffusion pipeline.

Output class for PRX pipelines.
