Training

Run the training script from the terminal: bash scripts/causalvae/train.sh

Training with a GAN loss requires two backward passes per iteration. When custom optimizers are implemented in PyTorch Lightning, this can double the training step count: each training loop is effectively counted as two steps. Keep this in mind when setting the total number of training steps and the step at which the GAN loss starts, because both values are otherwise easy to misjudge.
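
As a rough illustration of this bookkeeping, the sketch below converts between training iterations and counted steps, assuming exactly two counted steps per training loop as described above. The helper names are hypothetical and are not part of the repository.

```python
# Assumption: every training iteration is counted as two steps,
# because the GAN setup performs two optimizer updates per loop.
STEPS_PER_ITERATION = 2


def iterations_to_steps(iterations: int) -> int:
    """Counted steps after the given number of training iterations."""
    return iterations * STEPS_PER_ITERATION


def steps_to_iterations(steps: int) -> int:
    """Full training iterations completed once the counter reaches `steps`."""
    return steps // STEPS_PER_ITERATION


# Example: how many iterations a budget of 100000 counted steps covers,
# and how many counted steps 50000 iterations consume.
print(steps_to_iterations(100_000))
print(iterations_to_steps(50_000))
```

Under this assumption, a step budget should be chosen as twice the number of training iterations you actually intend to run.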

Code Structure

CausalVideoVAE is located in the directory opensora/models/ae/videobase. The directory structure is as follows:

.
├── causal_vae
├── causal_vqvae
├── configuration_videobase.py
├── dataset_videobase.py
├── __init__.py
├── losses
├── modeling_videobase.py
├── modules
├── __pycache__
├── trainer_videobase.py
├── utils
└── vqvae

The causal_vae directory defines the overall structure of the CausalVideoVAE model. The modules directory contains the building blocks the model needs, including CausalConv3d, ResnetBlock3D, and Attention. The losses directory contains the GAN loss, the perceptual loss, and other loss components.

Configuration

Model training relies on two key files. The config.json file configures the model structure, loss function, learning rate, and similar settings. The train.sh file configures the dataset, training steps, precision, and similar settings.

Model Configuration File

Taking the release version model configuration file release.json as an example:

{
  "_class_name": "CausalVAEModel",
  "_diffusers_version": "0.27.2",
  "attn_resolutions": [],
  "decoder_attention": "AttnBlock3D",
  "decoder_conv_in": "CausalConv3d",
  "decoder_conv_out": "CausalConv3d",
  "decoder_mid_resnet": "ResnetBlock3D",
  "decoder_resnet_blocks": [
    "ResnetBlock3D",
    "ResnetBlock3D",
    "ResnetBlock3D",
    "ResnetBlock3D"
  ],
  "decoder_spatial_upsample": [
    "",
    "SpatialUpsample2x",
    "SpatialUpsample2x",
    "SpatialUpsample2x"
  ],
  "decoder_temporal_upsample": [
    "",
    "",
    "TimeUpsample2x",
    "TimeUpsample2x"
  ],
  "double_z": true,
  "dropout": 0.0,
  "embed_dim": 4,
  "encoder_attention": "AttnBlock3D",
  "encoder_conv_in": "CausalConv3d",
  "encoder_conv_out": "CausalConv3d",
  "encoder_mid_resnet": "ResnetBlock3D",
  "encoder_resnet_blocks": [
    "ResnetBlock3D",
    "ResnetBlock3D",
    "ResnetBlock3D",
    "ResnetBlock3D"
  ],
  "encoder_spatial_downsample": [
    "SpatialDownsample2x",
    "SpatialDownsample2x",
    "SpatialDownsample2x",
    ""
  ],
  "encoder_temporal_downsample": [
    "TimeDownsample2x",
    "TimeDownsample2x",
    "",
    ""
  ],
  "hidden_size": 128,
  "hidden_size_mult": [
    1,
    2,
    4,
    4
  ],
  "loss_params": {
    "disc_start": 2001,
    "disc_weight": 0.5,
    "kl_weight": 1e-06,
    "logvar_init": 0.0
  },
  "loss_type": "opensora.models.ae.videobase.losses.LPIPSWithDiscriminator",
  "lr": 1e-05,
  "num_res_blocks": 2,
  "q_conv": "CausalConv3d",
  "resolution": 256,
  "z_channels": 4
}

This file configures the modules used in each layer of the encoder and decoder, as well as the loss. Editing it makes it easy to train different model structures.
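
To make the effect of these fields concrete, the sketch below derives the overall compression factors and per-stage channel widths from a trimmed copy of the configuration above. It assumes that each module whose name ends in "2x" halves the corresponding dimension, that an empty string means no resampling at that stage, and that hidden_size_mult scales hidden_size per stage. These are reading assumptions about the config, not confirmed behavior of the repository code, and compression_factor is a hypothetical helper.

```python
# Trimmed copy of the relevant fields from release.json.
config = {
    "hidden_size": 128,
    "hidden_size_mult": [1, 2, 4, 4],
    "encoder_spatial_downsample": [
        "SpatialDownsample2x",
        "SpatialDownsample2x",
        "SpatialDownsample2x",
        "",
    ],
    "encoder_temporal_downsample": ["TimeDownsample2x", "TimeDownsample2x", "", ""],
}


def compression_factor(stages):
    """Multiply by 2 for every stage whose module name ends in '2x'."""
    factor = 1
    for name in stages:
        if name.endswith("2x"):
            factor *= 2
    return factor


spatial = compression_factor(config["encoder_spatial_downsample"])
temporal = compression_factor(config["encoder_temporal_downsample"])
# Assumed channel width of each encoder stage.
widths = [config["hidden_size"] * m for m in config["hidden_size_mult"]]

print("spatial compression:", spatial)
print("temporal compression:", temporal)
print("stage widths:", widths)
```

Changing an entry in one of the downsample lists (for example, replacing "" with a "2x" module) changes the corresponding factor, which is one way to explore different model structures.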

Training Script

The parameters of train_causalvae.py are described below:

| Parameter | Default Value | Description |
| --- | --- | --- |
| --exp_name | "causalvae" | The name of the experiment, used for the folder where results are saved. |
| --batch_size | 1 | The number of samples per training iteration. |
| --precision | "bf16" | The numerical precision type used for training. |
| --max_steps | 100000 | The maximum number of steps for the training process. |
| --save_steps | 2000 | The interval at which to save the model during training. |
| --output_dir | "results/causalvae" | The directory where training results are saved. |
| --video_path | "/remote-home1/dataset/data_split_tt" | The path where the video data is stored. |
| --video_num_frames | 17 | The number of frames per video. |
| --sample_rate | 1 | The sampling rate, indicating the number of video frames per second. |
| --dynamic_sample | False | Whether to use dynamic sampling. |
| --model_config | "scripts/causalvae/288.yaml" | The path to the model configuration file. |
| --n_nodes | 1 | The number of nodes used for training. |
| --devices | 8 | The number of devices used for training. |
| --resolution | 256 | The resolution of the videos. |
| --num_workers | 8 | The number of subprocesses used for data handling. |
| --resume_from_checkpoint | None | Resume training from a specified checkpoint. |
| --load_from_checkpoint | None | Load the model from a specified checkpoint. |

Please ensure that the values provided for these parameters are appropriate for your training setup.
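
One way to sanity-check a set of values before launching a run is sketched below. The parser is a hypothetical stand-in that mirrors a few defaults from the parameter list above; the actual train_causalvae.py may declare its arguments differently. The validation rules are illustrative assumptions, in particular the check that the resolution is divisible by 8, which follows from reading the three SpatialDownsample2x stages in the release configuration as an 8x spatial reduction.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring a few documented defaults."""
    parser = argparse.ArgumentParser(description="CausalVideoVAE training (sketch)")
    parser.add_argument("--exp_name", type=str, default="causalvae")
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--precision", type=str, default="bf16")
    parser.add_argument("--max_steps", type=int, default=100000)
    parser.add_argument("--save_steps", type=int, default=2000)
    parser.add_argument("--video_num_frames", type=int, default=17)
    parser.add_argument("--resolution", type=int, default=256)
    return parser


def check_args(args: argparse.Namespace) -> None:
    """Illustrative sanity checks; these rules are assumptions, not repo behavior."""
    if args.batch_size < 1:
        raise ValueError("--batch_size must be at least 1")
    if args.video_num_frames < 1:
        raise ValueError("--video_num_frames must be at least 1")
    if args.save_steps > args.max_steps:
        raise ValueError("--save_steps is larger than --max_steps")
    if args.resolution % 8 != 0:
        # Assumed 8x spatial compression (three 2x downsampling stages).
        raise ValueError("--resolution should be divisible by 8")


# Parse an explicit argument list so the example does not depend on sys.argv.
args = build_parser().parse_args(["--max_steps", "200000", "--resolution", "256"])
check_args(args)
print(args)
```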

Evaluation

  1. Video Generation: The script scripts/causalvae/gen_video.sh in the repository generates videos. For its parameters, refer to the script itself.

  2. Video Evaluation: After generating the videos, you can evaluate them with the scripts/causalvae/eval.sh script. This evaluation script supports common metrics, including lpips, flolpips, ssim, psnr, and more.

Note that you must generate the videos before running the eval script, and the video parameters used for generation must match those used for evaluation.
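
To illustrate why the generation and evaluation parameters have to match, here is a minimal sketch of one of the listed metrics, PSNR, written in plain Python. It is not the implementation used by eval.sh; the function name and the flat pixel-sequence input are assumptions made for the example. The metric compares pixels position by position, so both inputs must have the same shape.

```python
import math


def psnr(reference, generated, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two equal-length pixel sequences."""
    if len(reference) != len(generated):
        raise ValueError("inputs must contain the same number of pixels")
    if len(reference) == 0:
        raise ValueError("inputs must not be empty")
    # Mean squared error over all pixel positions.
    mse = sum((r - g) ** 2 for r, g in zip(reference, generated)) / len(reference)
    if mse == 0:
        return math.inf  # identical inputs
    return 10.0 * math.log10(max_value ** 2 / mse)


# Toy example with two 4-pixel "frames" in the 0..255 range.
original_pixels = [100, 100, 100, 100]
reconstructed_pixels = [110, 90, 110, 90]
print(psnr(original_pixels, reconstructed_pixels))
```

Passing sequences of different lengths raises an error here, which mirrors the failure mode of evaluating videos whose resolution or frame count differs from the reference.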

How to Import a Trained Model

Our model class inherits from Hugging Face's configuration and model management classes, so models can be downloaded and loaded from Hugging Face. It can also import models trained with PyTorch Lightning.

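from opensora.models.ae.videobase import CausalVAEModel  # import path assumed from the directory layout above

# args.ckpt is the checkpoint to load; device is the target device, e.g. "cuda"
model = CausalVAEModel.from_pretrained(args.ckpt)
model = model.to(device)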