# Training
To start training, run the following command in a terminal: `bash scripts/causalvae/train.sh`
> When GAN loss is used, each training iteration requires two backward passes (one for the autoencoder and one for the discriminator). When multiple optimizers are used through PyTorch Lightning's [manual optimization](https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html#use-multiple-optimizers-like-gans), every optimizer step increments the global step counter, so each training iteration counts as two steps. Take this doubling into account when setting the total number of training steps and the step at which the GAN loss is enabled.
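The effect of this doubling can be sketched in plain Python. The sketch assumes, as the note above describes, that every `optimizer.step()` call advances the global step by one. `simulate` and `steps_to_iterations` are hypothetical helpers for illustration only, not part of the repository.

```python
def simulate(iterations, optimizers=("generator", "discriminator")):
    """Toy model of manual optimization with several optimizers.

    Assumption: each optimizer.step() advances the global step by one,
    so one training iteration advances it by len(optimizers).
    """
    global_step = 0
    for _ in range(iterations):
        for _name in optimizers:
            # backward pass + optimizer.step() for this optimizer
            global_step += 1
    return global_step


def steps_to_iterations(steps, num_optimizers=2):
    """Convert a step-based setting (e.g. --max_steps) into training iterations."""
    return steps // num_optimizers


print(simulate(1000))               # 2000 global steps for 1000 iterations
print(steps_to_iterations(100000))  # --max_steps 100000 -> 50000 iterations
```

In practice, this means a step-based value such as `--max_steps` or `disc_start` should be roughly twice the number of iterations you actually intend.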
## Code Structure
The CausalVideoVAE code is located in the `opensora/models/ae/videobase` directory, which is structured as follows:
```
.
├── causal_vae
├── causal_vqvae
├── configuration_videobase.py
├── dataset_videobase.py
├── __init__.py
├── losses
├── modeling_videobase.py
├── modules
├── __pycache__
├── trainer_videobase.py
├── utils
└── vqvae
```
The `causal_vae` directory defines the overall structure of the CausalVideoVAE model, and the `modules` directory contains the building blocks the model requires, including **CausalConv3D**, **ResnetBlock3D**, and **Attention**. The `losses` directory contains the **GAN loss**, the **perceptual loss**, and other loss components.
## Configuration
Model training requires two key files: one is the `config.json` file, which configures the model structure, loss function, learning rate, etc. The other is the `train.sh` file, which configures the dataset, training steps, precision, etc.
### Model Configuration File
Using the configuration file of the released model, `release.json`, as an example:
```json
{
"_class_name": "CausalVAEModel",
"_diffusers_version": "0.27.2",
"attn_resolutions": [],
"decoder_attention": "AttnBlock3D",
"decoder_conv_in": "CausalConv3d",
"decoder_conv_out": "CausalConv3d",
"decoder_mid_resnet": "ResnetBlock3D",
"decoder_resnet_blocks": [
"ResnetBlock3D",
"ResnetBlock3D",
"ResnetBlock3D",
"ResnetBlock3D"
],
"decoder_spatial_upsample": [
"",
"SpatialUpsample2x",
"SpatialUpsample2x",
"SpatialUpsample2x"
],
"decoder_temporal_upsample": [
"",
"",
"TimeUpsample2x",
"TimeUpsample2x"
],
"double_z": true,
"dropout": 0.0,
"embed_dim": 4,
"encoder_attention": "AttnBlock3D",
"encoder_conv_in": "CausalConv3d",
"encoder_conv_out": "CausalConv3d",
"encoder_mid_resnet": "ResnetBlock3D",
"encoder_resnet_blocks": [
"ResnetBlock3D",
"ResnetBlock3D",
"ResnetBlock3D",
"ResnetBlock3D"
],
"encoder_spatial_downsample": [
"SpatialDownsample2x",
"SpatialDownsample2x",
"SpatialDownsample2x",
""
],
"encoder_temporal_downsample": [
"TimeDownsample2x",
"TimeDownsample2x",
"",
""
],
"hidden_size": 128,
"hidden_size_mult": [
1,
2,
4,
4
],
"loss_params": {
"disc_start": 2001,
"disc_weight": 0.5,
"kl_weight": 1e-06,
"logvar_init": 0.0
},
"loss_type": "opensora.models.ae.videobase.losses.LPIPSWithDiscriminator",
"lr": 1e-05,
"num_res_blocks": 2,
"q_conv": "CausalConv3d",
"resolution": 256,
"z_channels": 4
}
```
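This file specifies which module is used at each level of the encoder and decoder (input/output convolutions, ResNet blocks, attention, and spatial/temporal up- and downsampling), as well as the loss function and its parameters. Different model architectures can be trained simply by editing this file.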
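As an illustration of how the downsampling entries determine the latent size, the sketch below reads the relevant fields (copied from `release.json` above) and estimates the latent shape of an input video. It assumes that each non-empty downsample entry halves its axis and that causal temporal downsampling keeps the first frame (`T -> 1 + (T - 1) // 2`). `latent_shape` is a hypothetical helper, not the model's own code.

```python
import json

# Excerpt of release.json containing only the fields this sketch reads.
RELEASE_JSON = """
{
  "embed_dim": 4,
  "encoder_spatial_downsample": ["SpatialDownsample2x", "SpatialDownsample2x", "SpatialDownsample2x", ""],
  "encoder_temporal_downsample": ["TimeDownsample2x", "TimeDownsample2x", "", ""]
}
"""


def latent_shape(config, num_frames, height, width):
    """Estimate the latent (C, T, H, W) for an input video.

    Assumptions: every non-empty downsample entry halves its axis, and the
    causal temporal downsampling keeps the first frame: T -> 1 + (T - 1) // 2.
    """
    n_spatial = sum(1 for m in config["encoder_spatial_downsample"] if m)
    n_temporal = sum(1 for m in config["encoder_temporal_downsample"] if m)
    t = num_frames
    for _ in range(n_temporal):
        t = 1 + (t - 1) // 2
    scale = 2 ** n_spatial
    return (config["embed_dim"], t, height // scale, width // scale)


config = json.loads(RELEASE_JSON)
print(latent_shape(config, num_frames=17, height=256, width=256))  # (4, 5, 32, 32)
```

Under these assumptions, the default 17-frame, 256x256 clips are compressed 4x in time (after the first frame) and 8x in each spatial dimension.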
### Training Script
The parameters accepted by `train_causalvae.py` are described below:
| Parameter | Default Value | Description |
|-----------------------------|-----------------|--------------------------------------------------------|
| `--exp_name` | "causalvae" | The name of the experiment, used for the folder where results are saved. |
| `--batch_size` | 1 | The number of samples per training iteration. |
| `--precision` | "bf16" | The numerical precision type used for training. |
| `--max_steps` | 100000 | The maximum number of training steps. |
| `--save_steps` | 2000 | The interval, in training steps, at which checkpoints are saved. |
| `--output_dir` | "results/causalvae" | The directory where training results are saved. |
| `--video_path` | "/remote-home1/dataset/data_split_tt" | The path where the video data is stored. |
| `--video_num_frames` | 17 | The number of frames per video. |
| `--sample_rate` | 1 | The sampling rate, indicating the number of video frames per second. |
| `--dynamic_sample` | False | Whether to use dynamic sampling. |
| `--model_config` | "scripts/causalvae/288.yaml" | The path to the model configuration file. |
| `--n_nodes` | 1 | The number of nodes used for training. |
| `--devices` | 8 | The number of devices used for training. |
| `--resolution` | 256 | The resolution of the videos. |
| `--num_workers` | 8 | The number of worker subprocesses used for data loading. |
| `--resume_from_checkpoint` | None | Resume training from a specified checkpoint. |
| `--load_from_checkpoint` | None | Load the model from a specified checkpoint. |
Adjust these values, in particular `--video_path`, `--devices`, and `--n_nodes`, to match your dataset location and hardware.
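To check how a given set of flags resolves, the table can be mirrored with `argparse`. This is a hypothetical parser built only from the documented names and defaults; the actual `train_causalvae.py` may define its arguments differently. The final lines assume standard data-parallel training, where each step consumes `batch_size` samples per device.

```python
import argparse


def build_parser():
    # Hypothetical parser mirroring the documented defaults; the real
    # train_causalvae.py may declare these arguments differently.
    p = argparse.ArgumentParser(description="CausalVAE training (sketch)")
    p.add_argument("--exp_name", type=str, default="causalvae")
    p.add_argument("--batch_size", type=int, default=1)
    p.add_argument("--precision", type=str, default="bf16")
    p.add_argument("--max_steps", type=int, default=100000)
    p.add_argument("--save_steps", type=int, default=2000)
    p.add_argument("--output_dir", type=str, default="results/causalvae")
    p.add_argument("--video_path", type=str, default="/remote-home1/dataset/data_split_tt")
    p.add_argument("--video_num_frames", type=int, default=17)
    p.add_argument("--sample_rate", type=int, default=1)
    p.add_argument("--dynamic_sample", action="store_true")  # default False
    p.add_argument("--model_config", type=str, default="scripts/causalvae/288.yaml")
    p.add_argument("--n_nodes", type=int, default=1)
    p.add_argument("--devices", type=int, default=8)
    p.add_argument("--resolution", type=int, default=256)
    p.add_argument("--num_workers", type=int, default=8)
    p.add_argument("--resume_from_checkpoint", type=str, default=None)
    p.add_argument("--load_from_checkpoint", type=str, default=None)
    return p


args = build_parser().parse_args(["--batch_size", "2", "--devices", "4"])
# Assuming data-parallel training: samples consumed per step across all devices.
global_batch = args.batch_size * args.devices * args.n_nodes
print(global_batch)  # 8
```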
# Evaluation
1. **Video generation**: The script `scripts/causalvae/gen_video.sh` in the repository generates videos. Refer to the script itself for its parameters.
2. **Video evaluation**: After generating videos, you can evaluate them with the `scripts/causalvae/eval.sh` script. It supports common metrics, including LPIPS, FloLPIPS, SSIM, and PSNR, among others.
> Note that the videos must be generated before running the evaluation script. Also make sure that the video parameters used for generation match those used for evaluation.
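For reference, PSNR, one of the metrics listed above, can be sketched in NumPy as follows. This is an illustrative implementation assuming pixel values in `[0, 1]` and videos shaped `(T, H, W, C)`, averaged over frames; `video_psnr` is a hypothetical helper and may differ from how `eval.sh` computes the metric. The shape check reflects why the generation and evaluation parameters must match.

```python
import numpy as np


def video_psnr(real, fake, max_val=1.0):
    """Mean per-frame PSNR between two videos of shape (T, H, W, C).

    Assumes pixel values in [0, max_val]; frames with zero error give inf.
    """
    real = np.asarray(real, dtype=np.float64)
    fake = np.asarray(fake, dtype=np.float64)
    if real.shape != fake.shape:
        raise ValueError(f"shape mismatch: {real.shape} vs {fake.shape}")
    # Mean squared error per frame, averaged over H, W and C.
    mse = ((real - fake) ** 2).reshape(real.shape[0], -1).mean(axis=1)
    with np.errstate(divide="ignore"):
        psnr = 10.0 * np.log10(max_val ** 2 / mse)
    return float(psnr.mean())


real = np.zeros((17, 8, 8, 3))
fake = np.full_like(real, 0.1)  # constant error of 0.1 -> MSE of 0.01
print(round(video_psnr(real, fake), 4))  # 20.0
```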
# How to Import a Trained Model
Our model class inherits from Hugging Face's configuration and model-management classes, so models can be downloaded from the Hugging Face Hub and loaded with `from_pretrained`. It can also load checkpoints trained with PyTorch Lightning.
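```python
import torch
from opensora.models.ae.videobase import CausalVAEModel

model = CausalVAEModel.from_pretrained(args.ckpt)  # args.ckpt: local path or Hub repo id
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
```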