
Training for FLUX


Environment Setup

  1. Create and activate a new conda environment:

    conda create -n omini python=3.10
    conda activate omini
    
  2. Install required packages:

    pip install -r requirements.txt
    

Dataset Preparation

  1. Download the Subject200K dataset for subject-driven generation:

    bash train/script/data_download/data_download1.sh
    
  2. Download the text-to-image-2M dataset for spatial alignment control tasks:

    bash train/script/data_download/data_download2.sh
    

    Note: By default, only a few files will be downloaded. You can edit data_download2.sh to download more data, and update the config file accordingly.
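    For illustration only, pulling extra shards might look like the snippet below. The dataset repo id and shard-naming pattern here are assumptions; mirror whatever pattern data_download2.sh actually uses.

    # Hypothetical extension of data_download2.sh: fetch more shards from
    # Hugging Face. Verify the repo id and file layout against the script.
    huggingface-cli download jackyhate/text-to-image-2M \
      --repo-type dataset \
      --include "data_512_2M/data_00000*.tar" \
      --local-dir dataset/text-to-image-2M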

Quick Start

Use these scripts to start training immediately:

  1. Subject-driven generation:

    bash train/script/train_subject.sh
    
  2. Spatial control tasks (Canny-to-image, colorization, depth map, etc.):

    bash train/script/train_spatial_alignment.sh
    
  3. Multi-condition training:

    bash train/script/train_multi_condition.sh
    
  4. Feature reuse (OminiControl2):

    bash train/script/train_feature_reuse.sh
    
  5. Compact token representation (OminiControl2):

    bash train/script/train_compact_token_representation.sh
    
  6. Token integration (OminiControl2):

    bash train/script/train_token_intergration.sh
    

Basic Training

Tasks from OminiControl


  1. Subject-driven generation:

    bash train/script/train_subject.sh
    
  2. Spatial control tasks (using Canny-to-image as an example):

    bash train/script/train_spatial_alignment.sh
    
    Supported tasks
    • Canny edge to image (canny)
    • Image colorization (coloring)
    • Image deblurring (deblurring)
    • Depth map to image (depth)
    • Image to depth map (depth_pred)
    • Image inpainting (fill)
    • Super resolution (sr)

    🌟 Change the condition_type parameter in the config file to switch between tasks.
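For example, switching from Canny-to-image to depth-to-image is a one-line config change (the exact key nesting is an assumption; see the files in train/configs/):

condition_type: depth  # one of: canny, coloring, deblurring, depth, depth_pred, fill, sr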

Note: Check the script files (train/script/) and config files (train/configs/) for WandB and GPU settings.

Creating Your Own Task

You can create a custom task by building a new dataset and modifying the test code:

  1. Create a custom dataset: Your custom dataset should follow the format of Subject200KDataset in omini/train_flux/train_subject.py (a minimal sketch appears after this list). Each sample should contain:

    • Image: the target image (image)
    • Text: description of the image (description)
    • Conditions: image conditions for generation
    • Position delta:
      • Use position_delta = (0, 0) to align the condition with the generated image
      • Use position_delta = (0, -a) to separate them (a = condition width / 16)

    Explanation:
    The model places both the condition and generated image in a shared coordinate system. position_delta shifts the condition image in this space.

    Each unit equals one patch (16 pixels). For a 512px-wide condition image (32 patches), position_delta = (0, -32) moves it fully to the left.

    This controls whether conditions and generated images share space or appear side-by-side.

  2. Modify the test code: Define test_function() in train_custom.py. Refer to the function in train_subject.py for examples. Make sure to keep the position_delta parameter consistent with your dataset.
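
For orientation, here is a minimal sketch of such a dataset. All field names, paths, and transforms below are illustrative assumptions modeled on the list above; copy the exact keys and preprocessing from Subject200KDataset in omini/train_flux/train_subject.py.

import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class MyConditionDataset(Dataset):
    """Toy custom dataset; mirror the real fields of Subject200KDataset."""

    def __init__(self, image_paths, condition_paths, descriptions, condition_size=512):
        assert len(image_paths) == len(condition_paths) == len(descriptions)
        self.image_paths = image_paths
        self.condition_paths = condition_paths
        self.descriptions = descriptions
        self.condition_size = condition_size
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        condition = Image.open(self.condition_paths[idx]).convert("RGB")
        condition = condition.resize((self.condition_size, self.condition_size))

        # (0, 0) overlays condition and target in the shared coordinate
        # system; (0, -width/16) places the condition fully to the left.
        position_delta = np.array([0, -self.condition_size // 16])

        return {
            "image": self.to_tensor(image),
            "condition": self.to_tensor(condition),
            "description": self.descriptions[idx],
            "position_delta": position_delta,
        }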

Training Configuration

Batch Size

We recommend a batch size of 1 for stable training. To simulate an effective batch size of n, set accumulate_grad_batches to n.
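
For example, assuming the setting lives in the train section of the config (the placement is an assumption; check the files in train/configs/):

train:
  accumulate_grad_batches: 8  # with batch size 1, simulates a batch size of 8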

Optimizer

The default optimizer is Prodigy. To use AdamW instead, modify the config file:

optimizer:
  type: AdamW
  lr: 1e-4
  weight_decay: 0.001

LoRA Configuration

The default LoRA rank is 4. Increase it for complex tasks, keeping r and lora_alpha equal:

lora_config:
  r: 128
  lora_alpha: 128

Trainable Modules

The target_modules parameter uses regex patterns to specify which modules to train. See PEFT Documentation for details.

Default configuration trains all modules affecting image tokens:

target_modules: "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)"

To train only attention components (to_q, to_k, to_v), use:

target_modules: "(.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v)"

Advanced Training

Multi-condition

A basic multi-condition implementation is available in train_multi_condition.py:

bash train/script/train_multi_condition.sh

Efficient Generation (OminiControl2)


OminiControl2 introduces techniques to improve generation efficiency:

Feature Reuse (KV-Cache)

  1. Enable independent_condition in the config file during training:

    model:
      independent_condition: true
    
  2. During inference, set kv_cache = True in the generate function to speed up generation.
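
As a rough sketch, an inference call might look like the following. The import paths, the Condition class, the LoRA path, and the rest of the generate() arguments are assumptions to be checked against the repo's inference examples; only the kv_cache flag comes from step 2 above.

import torch
from diffusers.pipelines import FluxPipeline
from PIL import Image

# These import paths are assumptions; see the repo's inference examples.
from src.flux.condition import Condition
from src.flux.generate import generate

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("path/to/trained_lora")  # hypothetical path

condition = Condition("subject", Image.open("subject.png"))
result = generate(
    pipe,
    prompt="a photo of the subject on a beach",
    conditions=[condition],
    kv_cache=True,  # reuse cached condition K/V across denoising steps
).images[0]
result.save("output.png")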

Training example:

bash train/script/train_feature_reuse.sh

Note: Feature reuse speeds up generation but may slightly reduce output quality and increase training time.

Compact Encoding Representation

Reduce the condition image resolution and use position_scale to align it with the output image. At a 256×256 condition size, the condition contributes only 16 × 16 = 256 tokens instead of 32 × 32 = 1024, while position_scale: 2 stretches its positional embeddings to span the full 512×512 output:

train:
  dataset:
    condition_size: 
-     - 512
-     - 512
+     - 256
+     - 256
+   position_scale: 2
    target_size: 
      - 512
      - 512

Example:

bash train/script/train_compact_token_representation.sh

Token Integration (for Fill task)

Further reduce tokens by merging condition and generation tokens into a unified sequence. (Refer to the paper for details.)

Example:

bash train/script/train_token_intergration.sh

Citation

If you find this code useful, please cite our papers:

@article{tan2024ominicontrol,
  title={OminiControl: Minimal and Universal Control for Diffusion Transformer},
  author={Tan, Zhenxiong and Liu, Songhua and Yang, Xingyi and Xue, Qiaochu and Wang, Xinchao},
  journal={arXiv preprint arXiv:2411.15098},
  year={2024}
}

@article{tan2025ominicontrol2,
  title={OminiControl2: Efficient Conditioning for Diffusion Transformers},
  author={Tan, Zhenxiong and Xue, Qiaochu and Yang, Xingyi and Liu, Songhua and Wang, Xinchao},
  journal={arXiv preprint arXiv:2503.08280},
  year={2025}
}