Spaces:
Runtime error
Runtime error
| import sys | |
| sys.path.append(".") | |
| import torch | |
| import random | |
| import numpy as np | |
| from opensora.models.ae.videobase import ( | |
| CausalVAEModel, | |
| ) | |
| from torch.utils.data import DataLoader | |
| from opensora.models.ae.videobase.dataset_videobase import VideoDataset | |
| import argparse | |
| from transformers import HfArgumentParser | |
| from dataclasses import dataclass, field, asdict | |
| import torch.distributed as dist | |
| import os | |
| import pytorch_lightning as pl | |
| from pytorch_lightning.loggers import WandbLogger | |
| from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor | |
@dataclass
class TrainingArguments:
    """Command-line options for CausalVAE training.

    Parsed by ``HfArgumentParser``, which only accepts dataclass types. Without
    the ``@dataclass`` decorator the ``field(...)`` values below would be plain
    class attributes and argument parsing would fail.
    """

    # Run name given to the W&B logger.
    exp_name: str = field(default="causalvae")
    # Batch size of the training DataLoader.
    batch_size: int = field(default=1)
    # Forwarded to ``pl.Trainer(precision=...)``.
    precision: str = field(default="bf16")
    # Forwarded to ``pl.Trainer(max_steps=...)``.
    max_steps: int = field(default=100000)
    # Checkpoint interval in training steps (ModelCheckpoint.every_n_train_steps).
    save_steps: int = field(default=2000)
    # Directory where checkpoints are written.
    output_dir: str = field(default="results/causalvae")
    # Root path handed to ``VideoDataset``.
    video_path: str = field(default="/remote-home1/dataset/data_split_tt")
    # Number of frames per clip (``VideoDataset`` sequence_length).
    video_num_frames: int = field(default=17)
    # Forwarded to ``VideoDataset(sample_rate=...)``.
    sample_rate: int = field(default=1)
    # Forwarded to ``VideoDataset(dynamic_sample=...)``.
    dynamic_sample: bool = field(default=False)
    # Model config used when not resuming (``CausalVAEModel.from_config``).
    model_config: str = field(default="scripts/causalvae/288.yaml")
    # Forwarded to ``pl.Trainer(num_nodes=...)``.
    n_nodes: int = field(default=1)
    # Forwarded to ``pl.Trainer(devices=...)``.
    devices: int = field(default=8)
    # Forwarded to ``VideoDataset(resolution=...)``.
    resolution: int = field(default=64)
    # DataLoader worker processes.
    num_workers: int = field(default=8)
    # When set, used both to load the model and as ``trainer.fit(ckpt_path=...)``;
    # None means train from ``model_config``.
    resume_from_checkpoint: str = field(default=None)
def set_seed(seed=1006):
    """Seed Python's, NumPy's and PyTorch's global RNGs with one value.

    Args:
        seed: Seed shared by all three generators (default 1006).
    """
    # The three generators are independent, so the seeding order is irrelevant.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
def load_callbacks_and_logger(args):
    """Create the Lightning callbacks and the W&B logger for a training run.

    Args:
        args: Parsed ``TrainingArguments``; reads ``output_dir``,
            ``save_steps`` and ``exp_name``.

    Returns:
        ``(callbacks, logger)``: ``callbacks`` is
        ``[ModelCheckpoint, LearningRateMonitor]`` and ``logger`` is a
        ``WandbLogger``.
    """
    # Keep every checkpoint (save_top_k=-1), one every `save_steps` training steps.
    checkpoint_options = dict(
        dirpath=args.output_dir,
        filename="model-{epoch:02d}-{step}",
        every_n_train_steps=args.save_steps,
        save_top_k=-1,
        save_on_train_epoch_end=False,
    )
    callback_list = [
        ModelCheckpoint(**checkpoint_options),
        LearningRateMonitor(logging_interval="step"),
    ]
    wandb_logger = WandbLogger(name=args.exp_name, log_model=False)
    return callback_list, wandb_logger
def _is_rank_zero():
    """Return True if this process should emit rank-zero-only output.

    Uses ``torch.distributed`` when a process group is initialized. Before
    ``trainer.fit`` there is usually no group yet, so fall back to the rank
    environment variables set by common launchers (``RANK`` for torchrun,
    ``LOCAL_RANK`` for Lightning's DDP subprocess launcher). A plain
    single-process run has neither and counts as rank zero.
    NOTE(review): with only LOCAL_RANK set, one process per node returns True.
    """
    if dist.is_initialized():
        return dist.get_rank() == 0
    return int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0"))) == 0


def _load_model(args):
    """Return a CausalVAEModel restored from a checkpoint or built from config.

    A non-empty ``args.resume_from_checkpoint`` goes to
    ``CausalVAEModel.from_pretrained``; otherwise the model is created from
    ``args.model_config`` via ``CausalVAEModel.from_config``.
    """
    if args.resume_from_checkpoint:
        return CausalVAEModel.from_pretrained(args.resume_from_checkpoint)
    return CausalVAEModel.from_config(args.model_config)


def _build_train_loader(args):
    """Build the shuffled, pinned-memory training DataLoader over VideoDataset."""
    dataset = VideoDataset(
        args.video_path,
        sequence_length=args.video_num_frames,
        resolution=args.resolution,
        sample_rate=args.sample_rate,
        dynamic_sample=args.dynamic_sample,
    )
    return DataLoader(
        dataset,
        shuffle=True,
        num_workers=args.num_workers,
        batch_size=args.batch_size,
        pin_memory=True,
    )


def train(args):
    """Train a CausalVAE on a video dataset with PyTorch Lightning.

    Args:
        args: Parsed ``TrainingArguments``.

    Seeds the RNGs, builds the model (from ``resume_from_checkpoint`` when
    set, else from ``model_config``), prints it once on rank zero, builds the
    training DataLoader and callbacks/logger, then runs ``trainer.fit``. A
    non-empty ``resume_from_checkpoint`` is also passed as ``ckpt_path`` so
    Lightning resumes training from it.
    """
    set_seed()

    # Build the model exactly once (no throwaway default-constructed model).
    model = _load_model(args)
    if _is_rank_zero():
        print(model)

    train_loader = _build_train_loader(args)
    callbacks, logger = load_callbacks_and_logger(args)

    trainer = pl.Trainer(
        accelerator="cuda",
        devices=args.devices,
        num_nodes=args.n_nodes,
        callbacks=callbacks,
        logger=logger,
        log_every_n_steps=5,
        precision=args.precision,
        max_steps=args.max_steps,
        # Presumably some parameters do not receive gradients every step,
        # hence unused-parameter detection -- confirm before switching to "ddp".
        strategy="ddp_find_unused_parameters_true",
    )

    # Same truthiness test as _load_model, so the model and trainer agree on
    # whether we are resuming. ckpt_path=None means "start fresh".
    # NOTE(review): this path feeds both from_pretrained and Lightning's
    # ckpt_path; confirm both accept the same format (directory vs .ckpt file).
    trainer.fit(
        model,
        train_loader,
        ckpt_path=args.resume_from_checkpoint or None,
    )
if __name__ == "__main__":
    # HfArgumentParser exposes each TrainingArguments field as a CLI flag and
    # returns a tuple with one parsed instance per dataclass it was given.
    (training_args,) = HfArgumentParser(TrainingArguments).parse_args_into_dataclasses()
    train(training_args)