PRX
PRX generates high-quality images from text using a simplified MMDiT architecture in which text tokens are not updated through the transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multilingual text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between the Flux VAE (8x compression, 16 latent channels) for balanced quality and speed, or DC-AE (32x compression, 32 latent channels) for higher latent compression and faster processing.
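To make the trade-off concrete, here is a quick sketch of the latent shapes implied by those compression factors and channel counts for a 512x512 image (illustrative arithmetic only, not a library API):

# Latent grid = image size / VAE compression factor, with the channel count fixed by the VAE.
def latent_shape(image_size: int, compression: int, channels: int) -> tuple[int, int, int]:
    return (channels, image_size // compression, image_size // compression)

print(latent_shape(512, 8, 16))   # Flux VAE -> (16, 64, 64): larger latent grid, more spatial detail
print(latent_shape(512, 32, 32))  # DC-AE    -> (32, 16, 16): 4x fewer latents per side, faster processing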
Available models
PRX offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the Alchemist dataset improve aesthetic quality, especially with simpler prompts.
| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
|---|---|---|---|---|---|---|---|
| Photoroom/prx-256-t2i | 256 | No | No | Base model pre-trained at 256 with Flux VAE | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | torch.bfloat16 |
| Photoroom/prx-256-t2i-sft | 256 | Yes | No | Fine-tuned on the Alchemist dataset with Flux VAE | Can handle less detailed prompts | 28 steps, cfg=5.0 | torch.bfloat16 |
| Photoroom/prx-512-t2i | 512 | No | No | Base model pre-trained at 512 with Flux VAE | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | torch.bfloat16 |
| Photoroom/prx-512-t2i-sft | 512 | Yes | No | Fine-tuned on the Alchemist dataset with Flux VAE | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | torch.bfloat16 |
| Photoroom/prx-512-t2i-sft-distilled | 512 | Yes | Yes | 8-step distilled model from Photoroom/prx-512-t2i-sft | Can handle less detailed prompts in natural language | 8 steps, cfg=1.0 | torch.bfloat16 |
| Photoroom/prx-512-t2i-dc-ae | 512 | No | No | Base model pre-trained at 512 with the Deep Compression Autoencoder (DC-AE) | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | torch.bfloat16 |
| Photoroom/prx-512-t2i-dc-ae-sft | 512 | Yes | No | Fine-tuned on the Alchemist dataset with the Deep Compression Autoencoder (DC-AE) | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | torch.bfloat16 |
| Photoroom/prx-512-t2i-dc-ae-sft-distilled | 512 | Yes | Yes | 8-step distilled model from Photoroom/prx-512-t2i-dc-ae-sft | Can handle less detailed prompts in natural language | 8 steps, cfg=1.0 | torch.bfloat16 |
Refer to this collection for more information.
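For instance, the distilled checkpoints are meant to run with the table's suggested parameters of 8 steps and cfg=1.0. A minimal sketch (the prompt is illustrative; pipeline loading is covered in the next section):

import torch
from diffusers.pipelines.prx import PRXPipeline

# Distilled variants trade classifier-free guidance for speed: 8 steps, guidance_scale=1.0
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft-distilled", torch_dtype=torch.bfloat16)
pipe.to("cuda")
image = pipe("A red bicycle leaning against a brick wall", num_inference_steps=8, guidance_scale=1.0).images[0]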
Loading the pipeline
Load the pipeline with from_pretrained().
import torch
from diffusers.pipelines.prx import PRXPipeline

# Load the pipeline - the VAE and text encoder are downloaded from the Hugging Face Hub
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A front-facing portrait of a lion in the golden savanna at sunset."
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
image.save("prx_output.png")
Manual Component Loading
Load components individually to customize the pipeline, for instance to use quantized models.
import torch
from diffusers.pipelines.prx import PRXPipeline
from diffusers.models import AutoencoderKL, AutoencoderDC
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import T5GemmaModel, GemmaTokenizerFast
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

# Use the diffusers config for diffusers models and the transformers config for the text encoder
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
text_quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True)
# Load transformer
transformer = PRXTransformer2DModel.from_pretrained(
    "Photoroom/prx-512-t2i-sft",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
# Load scheduler
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "Photoroom/prx-512-t2i-sft", subfolder="scheduler"
)
# Load T5Gemma text encoder
t5gemma_model = T5GemmaModel.from_pretrained(
    "google/t5gemma-2b-2b-ul2",
    quantization_config=text_quant_config,
    torch_dtype=torch.bfloat16,
)
text_encoder = t5gemma_model.encoder
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
tokenizer.model_max_length = 256
# Load VAE - choose either Flux VAE or DC-AE
# Flux VAE
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
pipe = PRXPipeline(
    transformer=transformer,
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    vae=vae,
)
pipe.to("cuda")
Memory Optimization
For memory-constrained environments:
import torch
from diffusers.pipelines.prx import PRXPipeline
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # Offload components to CPU when not in use

# Alternatively, use sequential CPU offload for even lower memory (slower):
# pipe.enable_sequential_cpu_offload()
PRXPipeline
class diffusers.PRXPipeline
< source >( transformer: PRXTransformer2DModel scheduler: FlowMatchEulerDiscreteScheduler text_encoder: T5GemmaEncoder tokenizer: typing.Union[transformers.models.t5.tokenization_t5_fast.T5TokenizerFast, transformers.models.gemma.tokenization_gemma_fast.GemmaTokenizerFast, transformers.models.auto.tokenization_auto.AutoTokenizer] vae: typing.Union[diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL, diffusers.models.autoencoders.autoencoder_dc.AutoencoderDC, NoneType] = None default_sample_size: typing.Optional[int] = 512 )
Parameters
- transformer (PRXTransformer2DModel) — The PRX transformer model to denoise the encoded image latents.
- scheduler (FlowMatchEulerDiscreteScheduler) — A scheduler to be used in combination with transformer to denoise the encoded image latents.
- text_encoder (T5GemmaEncoder) — Text encoder model for encoding prompts.
- tokenizer (T5TokenizerFast or GemmaTokenizerFast) — Tokenizer for the text encoder.
- vae (AutoencoderKL or AutoencoderDC) — Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. Supports both AutoencoderKL (8x compression) and AutoencoderDC (32x compression).
Pipeline for text-to-image generation using PRX Transformer.
This model inherits from DiffusionPipeline. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
__call__
< source >( prompt: typing.Union[str, typing.List[str]] = None negative_prompt: str = '' height: typing.Optional[int] = None width: typing.Optional[int] = None num_inference_steps: int = 28 timesteps: typing.List[int] = None guidance_scale: float = 4.0 num_images_per_prompt: typing.Optional[int] = 1 generator: typing.Union[torch._C.Generator, typing.List[torch._C.Generator], NoneType] = None latents: typing.Optional[torch.Tensor] = None prompt_embeds: typing.Optional[torch.FloatTensor] = None negative_prompt_embeds: typing.Optional[torch.FloatTensor] = None prompt_attention_mask: typing.Optional[torch.BoolTensor] = None negative_prompt_attention_mask: typing.Optional[torch.BoolTensor] = None output_type: typing.Optional[str] = 'pil' return_dict: bool = True use_resolution_binning: bool = True callback_on_step_end: typing.Optional[typing.Callable[[int, int, typing.Dict], NoneType]] = None callback_on_step_end_tensor_inputs: typing.List[str] = ['latents'] ) → PRXPipelineOutput or tuple
Parameters
- prompt (
strorList[str], optional) — The prompt or prompts to guide the image generation. If not defined, one has to passprompt_embedsinstead. - negative_prompt (
str, optional, defaults to"") — The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored ifguidance_scaleis less than1). - height (
int, optional, defaults to self.transformer.config.sample_size * self.vae_scale_factor) — The height in pixels of the generated image. - width (
int, optional, defaults to self.transformer.config.sample_size * self.vae_scale_factor) — The width in pixels of the generated image. - num_inference_steps (
int, optional, defaults to 28) — The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - timesteps (
List[int], optional) — Custom timesteps to use for the denoising process with schedulers which support atimestepsargument in theirset_timestepsmethod. If not defined, the default behavior whennum_inference_stepsis passed will be used. Must be in descending order. - guidance_scale (
float, optional, defaults to 4.0) — Guidance scale as defined in Classifier-Free Diffusion Guidance.guidance_scaleis defined aswof equation 2. of Imagen Paper. Guidance scale is enabled by settingguidance_scale > 1. Higher guidance scale encourages to generate images that are closely linked to the textprompt, usually at the expense of lower image quality. - num_images_per_prompt (
int, optional, defaults to 1) — The number of images to generate per prompt. - generator (
torch.GeneratororList[torch.Generator], optional) — One or a list of torch generator(s) to make generation deterministic. - latents (
torch.Tensor, optional) — Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied randomgenerator. - prompt_embeds (
torch.FloatTensor, optional) — Pre-generated text embeddings. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided, text embeddings will be generated frompromptinput argument. - negative_prompt_embeds (
torch.FloatTensor, optional) — Pre-generated negative text embeddings. Can be used to easily tweak text inputs, e.g. prompt weighting. If not provided andguidance_scale > 1, negative embeddings will be generated from an empty string. - prompt_attention_mask (
torch.BoolTensor, optional) — Pre-generated attention mask forprompt_embeds. If not provided, attention mask will be generated frompromptinput argument. - negative_prompt_attention_mask (
torch.BoolTensor, optional) — Pre-generated attention mask fornegative_prompt_embeds. If not provided andguidance_scale > 1, attention mask will be generated from an empty string. - output_type (
str, optional, defaults to"pil") — The output format of the generate image. Choose between PIL:PIL.Image.Imageornp.array. - return_dict (
bool, optional, defaults toTrue) — Whether or not to return a PRXPipelineOutput instead of a plain tuple. - use_resolution_binning (
bool, optional, defaults toTrue) — If set toTrue, the requested height and width are first mapped to the closest resolutions using predefined aspect ratio bins. After the produced latents are decoded into images, they are resized back to the requested resolution. Useful for generating non-square images at optimal resolutions. - callback_on_step_end (
Callable, optional) — A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments:callback_on_step_end(self, step, timestep, callback_kwargs).callback_kwargswill include a list of all tensors as specified bycallback_on_step_end_tensor_inputs. - callback_on_step_end_tensor_inputs (
List, optional) — The list of tensor inputs for thecallback_on_step_endfunction. The tensors specified in the list will be passed ascallback_kwargsargument. You will only be able to include tensors that are listed in the._callback_tensor_inputsattribute.
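For example, callback_on_step_end can be used to inspect intermediate latents. A minimal sketch (log_latents and the prompt are illustrative; returning callback_kwargs follows the usual diffusers callback convention):

import torch
from diffusers.pipelines.prx import PRXPipeline

def log_latents(pipe, step, timestep, callback_kwargs):
    # "latents" is available here because it is listed in callback_on_step_end_tensor_inputs
    latents = callback_kwargs["latents"]
    print(f"step {step}: latent norm = {latents.norm().item():.2f}")
    return callback_kwargs

pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.to("cuda")
image = pipe(
    "A lighthouse on a cliff at dawn",
    num_inference_steps=28,
    guidance_scale=5.0,
    callback_on_step_end=log_latents,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]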
Returns
PRXPipelineOutput or tuple
PRXPipelineOutput if return_dict is True, otherwise a tuple. When returning a tuple, the first element is a list with the generated images.
Function invoked when calling the pipeline for generation.
Examples:
>>> import torch
>>> from diffusers import PRXPipeline
>>> # Load pipeline with from_pretrained
>>> pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
>>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
>>> image.save("prx_output.png")
check_inputs
< source >( prompt: typing.Union[str, typing.List[str]] height: int width: int guidance_scale: float callback_on_step_end_tensor_inputs: typing.Optional[typing.List[str]] = None prompt_embeds: typing.Optional[torch.FloatTensor] = None negative_prompt_embeds: typing.Optional[torch.FloatTensor] = None )
Check that all inputs are in correct format.
encode_prompt
< source >( prompt: typing.Union[str, typing.List[str]] device: typing.Optional[torch.device] = None do_classifier_free_guidance: bool = True negative_prompt: str = '' num_images_per_prompt: int = 1 prompt_embeds: typing.Optional[torch.FloatTensor] = None negative_prompt_embeds: typing.Optional[torch.FloatTensor] = None prompt_attention_mask: typing.Optional[torch.BoolTensor] = None negative_prompt_attention_mask: typing.Optional[torch.BoolTensor] = None )
Encode text prompt using standard text encoder and tokenizer, or use precomputed embeddings.
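encode_prompt can be used to cache embeddings and reuse them across calls via the prompt_embeds arguments documented above. A minimal sketch, assuming the method returns (prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask) in that order; verify the return order against your installed version:

import torch
from diffusers.pipelines.prx import PRXPipeline

pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Assumed return order - check the source of encode_prompt in your diffusers version.
prompt_embeds, prompt_mask, neg_embeds, neg_mask = pipe.encode_prompt(
    "A watercolor fox in a snowy forest", device="cuda"
)

# Reuse the cached embeddings without re-encoding the prompt.
image = pipe(
    prompt_embeds=prompt_embeds,
    prompt_attention_mask=prompt_mask,
    negative_prompt_embeds=neg_embeds,
    negative_prompt_attention_mask=neg_mask,
    num_inference_steps=28,
    guidance_scale=5.0,
).images[0]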
get_default_resolution
< source >( ) → int
Returns
int
The default sample size (height/width) to use for generation.
Determine the default resolution based on the loaded VAE and config.
prepare_latents
< source >( batch_size: int num_channels_latents: int height: int width: int dtype: dtype device: device generator: typing.Optional[torch._C.Generator] = None latents: typing.Optional[torch.Tensor] = None )
Prepare initial latents for the diffusion process.
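prepare_latents is called internally by the pipeline; to control the initial noise yourself, pass a seeded generator (or a pre-sampled latents tensor) to the pipeline call, as documented in __call__ above. A short sketch:

import torch
from diffusers.pipelines.prx import PRXPipeline

pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# A seeded generator makes the initial latent sampling, and hence the image, reproducible.
generator = torch.Generator(device="cuda").manual_seed(42)
image = pipe(
    "A ceramic teapot on a wooden table",
    num_inference_steps=28,
    guidance_scale=5.0,
    generator=generator,
).images[0]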
PRXPipelineOutput
class diffusers.pipelines.prx.PRXPipelineOutput
< source >( images: typing.Union[typing.List[PIL.Image.Image], numpy.ndarray] )
Output class for PRX pipelines.