# Repo & Config Structure

## Repo Structure

```plaintext
Open-Sora
β”œβ”€β”€ README.md
β”œβ”€β”€ docs
β”‚   β”œβ”€β”€ acceleration.md            -> Acceleration & Speed benchmark
β”‚   β”œβ”€β”€ command.md                 -> Commands for training & inference
β”‚   β”œβ”€β”€ datasets.md                -> Datasets used in this project
β”‚   β”œβ”€β”€ structure.md               -> This file
β”‚   └── report_v1.md               -> Report for Open-Sora v1
β”œβ”€β”€ scripts
β”‚   β”œβ”€β”€ train.py                   -> diffusion training script
β”‚   └── inference.py               -> diffusion inference script
β”œβ”€β”€ configs                        -> Configs for training & inference
β”œβ”€β”€ opensora
β”‚   β”œβ”€β”€ __init__.py
β”‚   β”œβ”€β”€ registry.py                -> Registry helper
β”‚   β”œβ”€β”€ acceleration               -> Acceleration related code
β”‚   β”œβ”€β”€ dataset                    -> Dataset related code
β”‚   β”œβ”€β”€ models
β”‚   β”‚   β”œβ”€β”€ layers                 -> Common layers
β”‚   β”‚   β”œβ”€β”€ vae                    -> VAE as image encoder
β”‚   β”‚   β”œβ”€β”€ text_encoder           -> Text encoder
β”‚   β”‚   β”‚   β”œβ”€β”€ classes.py         -> Class id encoder (inference only)
β”‚   β”‚   β”‚   β”œβ”€β”€ clip.py            -> CLIP encoder
β”‚   β”‚   β”‚   └── t5.py              -> T5 encoder
β”‚   β”‚   β”œβ”€β”€ dit
β”‚   β”‚   β”œβ”€β”€ latte
β”‚   β”‚   β”œβ”€β”€ pixart
β”‚   β”‚   └── stdit                  -> Our STDiT related code
β”‚   β”œβ”€β”€ schedulers                 -> Diffusion schedulers
β”‚   β”‚   β”œβ”€β”€ iddpm                  -> IDDPM for training and inference
β”‚   β”‚   └── dpms                   -> DPM-Solver for fast inference
β”‚   └── utils
└── tools                          -> Tools for data processing and more
```

## Configs

Our config files follow [MMEngine](https://github.com/open-mmlab/mmengine). MMEngine reads a config file (a `.py` file) and parses it into a dictionary-like object.
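
For example, any config below can be loaded with MMEngine and its fields accessed either as attributes or as dictionary keys. A minimal sketch (the config path and field names follow the demos in this document):

```python
from mmengine.config import Config

# Parse a Python config file into a dict-like Config object
cfg = Config.fromfile("configs/opensora/inference/64x512x512.py")

print(cfg.num_frames)        # attribute-style access, e.g. 64
print(cfg["model"]["type"])  # dict-style access, e.g. "STDiT-XL/2"
```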

```plaintext
Open-Sora
└── configs                        -> Configs for training & inference
    β”œβ”€β”€ opensora                   -> STDiT related configs
    β”‚   β”œβ”€β”€ inference
    β”‚   β”‚   β”œβ”€β”€ 16x256x256.py      -> Sample videos 16 frames 256x256
    β”‚   β”‚   β”œβ”€β”€ 16x512x512.py      -> Sample videos 16 frames 512x512
    β”‚   β”‚   └── 64x512x512.py      -> Sample videos 64 frames 512x512
    β”‚   └── train
    β”‚       β”œβ”€β”€ 16x256x256.py      -> Train on videos 16 frames 256x256
    β”‚       β”œβ”€β”€ 16x512x512.py      -> Train on videos 16 frames 512x512
    β”‚       └── 64x512x512.py      -> Train on videos 64 frames 512x512
    β”œβ”€β”€ dit                        -> DiT related configs
    β”‚   β”œβ”€β”€ inference
    β”‚   β”‚   β”œβ”€β”€ 1x256x256-class.py -> Sample images with ckpts from DiT
    β”‚   β”‚   β”œβ”€β”€ 1x256x256.py       -> Sample images with clip condition
    β”‚   β”‚   └── 16x256x256.py      -> Sample videos
    β”‚   └── train
    β”‚       β”œβ”€β”€ 1x256x256.py       -> Train on images with clip condition
    β”‚       └── 16x256x256.py      -> Train on videos
    β”œβ”€β”€ latte                      -> Latte related configs
    └── pixart                     -> PixArt related configs
```

## Inference config demos

To change the inference settings, you can modify the corresponding config file directly, or pass command-line arguments to override fields in the config file ([config_utils.py](/opensora/utils/config_utils.py)). To change the sampling prompts, modify the `.txt` file passed to the `--prompt_path` argument.

```plaintext
--prompt_path ./assets/texts/t2v_samples.txt  -> prompt_path
--ckpt-path ./path/to/your/ckpt.pth           -> model["from_pretrained"]
```
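
Conceptually, such command-line overrides are merged back into the parsed config. A hedged sketch of the idea using MMEngine's `merge_from_dict` (the real argument handling lives in [config_utils.py](/opensora/utils/config_utils.py); the dict below only mirrors the two flags above):

```python
from mmengine.config import Config

cfg = Config.fromfile("configs/opensora/inference/64x512x512.py")

# Overrides mirroring the CLI flags above; dotted keys update
# nested fields such as model["from_pretrained"].
cfg.merge_from_dict({
    "prompt_path": "./assets/texts/t2v_samples.txt",
    "model.from_pretrained": "./path/to/your/ckpt.pth",
})
```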

The explanation of each field is provided below.

```python
# Define sampling size
num_frames = 64               # number of frames
fps = 24 // 2                 # frames per second (divided by 2 for frame_interval=2)
image_size = (512, 512)       # image size (height, width)

# Define model
model = dict(
    type="STDiT-XL/2",        # Select model type (STDiT-XL/2, DiT-XL/2, etc.)
    space_scale=1.0,          # (Optional) Space positional encoding scale (new height / old height)
    time_scale=2 / 3,         # (Optional) Time positional encoding scale (new frame_interval / old frame_interval)
    enable_flashattn=True,    # (Optional) Speed up training and inference with flash attention
    enable_layernorm_kernel=True, # (Optional) Speed up training and inference with fused kernel
    from_pretrained="PRETRAINED_MODEL",  # (Optional) Load from pretrained model
    no_temporal_pos_emb=True,  # (Optional) Disable temporal positional encoding (for images)
)
vae = dict(
    type="VideoAutoencoderKL", # Select VAE type
    from_pretrained="stabilityai/sd-vae-ft-ema", # Load from pretrained VAE
    micro_batch_size=128,      # Encode frames in micro-batches of this size to save memory
)
text_encoder = dict(
    type="t5",                 # Select text encoder type (t5, clip)
    from_pretrained="./pretrained_models/t5_ckpts", # Load from pretrained text encoder
    model_max_length=120,      # Maximum length of input text
)
scheduler = dict(
    type="iddpm",              # Select scheduler type (iddpm, dpm-solver)
    num_sampling_steps=100,    # Number of sampling steps
    cfg_scale=7.0,             # Classifier-free guidance scale
)
dtype = "fp16"                 # Computation type (fp16, fp32, bf16)

# Other settings
batch_size = 1                 # batch size
seed = 42                      # random seed
prompt_path = "./assets/texts/t2v_samples.txt"  # path to prompt file
save_dir = "./samples"         # path to save samples
```
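
Each `type` string above is resolved through a registry (see `opensora/registry.py` in the tree): the dict is turned into an instance of the class registered under that name, with the remaining keys passed as constructor arguments. A minimal, self-contained sketch of the mechanism with MMEngine's `Registry` (the class here is a hypothetical stand-in, not the real STDiT):

```python
from mmengine.registry import Registry

MODELS = Registry("model")  # toy registry illustrating the mechanism

@MODELS.register_module("STDiT-XL/2")
class DummySTDiT:
    """Hypothetical stand-in for the real STDiT-XL/2 model."""
    def __init__(self, space_scale=1.0, time_scale=1.0, **kwargs):
        self.space_scale = space_scale
        self.time_scale = time_scale

# A config dict with a "type" key builds the registered class;
# the other keys become constructor arguments.
model = MODELS.build(dict(type="STDiT-XL/2", space_scale=1.0, time_scale=2 / 3))
```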

## Training config demos

```python
# Define sampling size
num_frames = 64
frame_interval = 2             # sample every 2 frames
image_size = (512, 512)

# Define dataset
root = None                    # root path to the dataset
data_path = "CSV_PATH"         # path to the csv file
use_image_transform = False    # True if training on images
num_workers = 4                # number of workers for dataloader

# Define acceleration
dtype = "bf16"                 # Computation type (fp16, bf16)
grad_checkpoint = True         # Use gradient checkpointing
plugin = "zero2"               # Plugin for distributed training (zero2, zero2-seq)
sp_size = 1                    # Sequence parallelism size (1 for no sequence parallelism)

# Define model
model = dict(
    type="STDiT-XL/2",
    space_scale=1.0,
    time_scale=2 / 3,
    from_pretrained="YOUR_PRETRAINED_MODEL",
    enable_flashattn=True,        # Enable flash attention
    enable_layernorm_kernel=True, # Enable layernorm kernel
)
vae = dict(
    type="VideoAutoencoderKL",
    from_pretrained="stabilityai/sd-vae-ft-ema",
    micro_batch_size=128,
)
text_encoder = dict(
    type="t5",
    from_pretrained="./pretrained_models/t5_ckpts",
    model_max_length=120,
    shardformer=True,           # Enable shardformer for T5 acceleration
)
scheduler = dict(
    type="iddpm",
    timestep_respacing="",      # Default 1000 timesteps
)

# Others
seed = 42
outputs = "outputs"             # path to save checkpoints
wandb = False                   # Use wandb for logging

epochs = 1000                   # number of epochs (set large enough; stop training manually when satisfied)
log_every = 10
ckpt_every = 250
load = None                     # path to resume training

batch_size = 4
lr = 2e-5
grad_clip = 1.0                 # gradient clipping
```
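
As a rough guide to these last few fields: assuming `batch_size` is the per-GPU batch size and `log_every` / `ckpt_every` are counted in optimizer steps, the wall-clock checkpoint interval depends on the dataset size and the number of GPUs. A back-of-the-envelope sketch with hypothetical numbers:

```python
# Hypothetical numbers; only batch_size and ckpt_every come from the config above.
num_samples = 100_000    # rows in the dataset CSV
world_size = 8           # number of GPUs
batch_size = 4           # per-GPU batch size
ckpt_every = 250         # checkpoint interval in steps

steps_per_epoch = num_samples // (batch_size * world_size)  # 3125
ckpts_per_epoch = steps_per_epoch // ckpt_every             # 12
print(steps_per_epoch, ckpts_per_epoch)
```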