#!/bin/bash
#
# Launches train/train_pyramid_flow.py via torchrun on the local node, using
# the settings defined below.
#
# Usage: edit the settings in the "Settings" section, then run this script
# from the directory that contains train/ (the training script and the
# annotation file are referenced by relative path).

set -euo pipefail

# ------------------------------- Settings ----------------------------------

# Number of processes torchrun starts on this node (--nproc_per_node).
GPUS=8

# Value forwarded to --task.
TASK=t2i

# Value forwarded to --fsdp_shard_strategy.
SHARD_STRATEGY=zero2

# Value forwarded to --model_name.
MODEL_NAME=pyramid_flux

# Value forwarded to --model_path.
# NOTE(review): "/PATH/..." looks like a placeholder -- replace it with the
# real model directory before running; presumably it must correspond to
# MODEL_NAME (confirm against the training script).
MODEL_PATH=/PATH/pyramid-flow-miniflux

# Value forwarded to --model_variant.
VARIANT=diffusion_transformer_image

# Value forwarded to --output_dir.
# NOTE(review): "/PATH/..." looks like a placeholder -- replace before running.
OUTPUT_DIR=/PATH/output_dir

# Value forwarded to --max_frames.
NUM_FRAMES=8

# Value forwarded to --batch_size.
BATCH_SIZE=4

# Value forwarded to --resolution.
RESOLUTION="768p"

# Value forwarded to --anno_file (relative to the current working directory).
ANNO_FILE=annotation/image_text.jsonl

# ------------------------------- Launch ------------------------------------

# Arguments for train/train_pyramid_flow.py. Kept in an array so every value
# reaches the program as exactly one word, even if a path contains spaces or
# glob characters, and so no backslash line continuations are needed.
train_args=(
  --num_workers 8
  --task "$TASK"
  --use_fsdp
  --fsdp_shard_strategy "$SHARD_STRATEGY"
  --use_flash_attn
  --load_text_encoder
  --load_vae
  --model_name "$MODEL_NAME"
  --model_path "$MODEL_PATH"
  --model_dtype bf16
  --model_variant "$VARIANT"
  --schedule_shift 1.0
  --gradient_accumulation_steps 1
  --output_dir "$OUTPUT_DIR"
  --batch_size "$BATCH_SIZE"
  --max_frames "$NUM_FRAMES"
  --resolution "$RESOLUTION"
  --anno_file "$ANNO_FILE"
  --frame_per_unit 1
  --lr_scheduler constant_with_warmup
  --opt adamw
  --opt_beta1 0.9
  --opt_beta2 0.95
  --seed 42
  --weight_decay 1e-4
  --clip_grad 1.0
  --lr 1e-4
  --warmup_steps 1000
  --epochs 20
  --iters_per_epoch 2000
  --report_to tensorboard
  --print_freq 40
  --save_ckpt_freq 1
)

# torchrun's exit status becomes the script's exit status.
torchrun --nproc_per_node "$GPUS" \
  train/train_pyramid_flow.py "${train_args[@]}"