"""Training entry point: build the model (optionally with EMA) and run DDPMTrainer."""
import sys
import os
from os.path import join as pjoin

from options.train_options import TrainOptions
from utils.plot_script import *
from models import build_models
from utils.ema import ExponentialMovingAverage
from trainers import DDPMTrainer
from motion_loader import get_dataset_loader

from accelerate.utils import set_seed
from accelerate import Accelerator
import torch

import yaml
from box import Box


def yaml_to_box(yaml_file):
    """Load a YAML file and wrap its parsed content in a ``Box``.

    Args:
        yaml_file: Path to the YAML file to read.

    Returns:
        A ``Box`` built from the result of ``yaml.safe_load``.
    """
    # Explicit encoding so parsing does not depend on the platform default.
    with open(yaml_file, 'r', encoding='utf-8') as file:
        yaml_data = yaml.safe_load(file)

    return Box(yaml_data)


if __name__ == '__main__':
    accelerator = Accelerator()

    parser = TrainOptions()
    opt = parser.parse(accelerator)
    set_seed(opt.seed)

    # NOTE(review): anomaly detection is a debugging aid and adds overhead to
    # every backward pass -- confirm it is meant to stay on for regular runs.
    torch.autograd.set_detect_anomaly(True)

    # Output layout: <checkpoints_dir>/<dataset_name>/<name>/{model,meta}
    opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
    opt.model_dir = pjoin(opt.save_root, 'model')
    opt.meta_dir = pjoin(opt.save_root, 'meta')

    # The edit configuration file is selected by the edit_mode option.
    if opt.edit_mode:
        edit_config = yaml_to_box('options/edit.yaml')
    else:
        edit_config = yaml_to_box('options/noedit.yaml')

    # Only the main process creates the output directories.
    if accelerator.is_main_process:
        os.makedirs(opt.model_dir, exist_ok=True)
        os.makedirs(opt.meta_dir, exist_ok=True)

    train_datasetloader = get_dataset_loader(
        opt,
        batch_size=opt.batch_size,
        split='train',
        accelerator=accelerator,
        mode='train',
    )  # 7169 (presumably the training-set size -- TODO confirm)

    accelerator.print('\nInitializing model ...')
    encoder = build_models(opt, edit_config=edit_config)

    model_ema = None
    if opt.model_ema:
        # Decay adjustment that aims to keep the decay independent of other
        # hyper-parameters, originally proposed at:
        # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
        adjust = 106_667 * opt.model_ema_steps / opt.num_train_steps
        alpha = 1.0 - opt.model_ema_decay
        alpha = min(1.0, alpha * adjust)
        # Use accelerator.print (like the other status messages here) so the
        # value is not repeated by every process in a distributed launch.
        accelerator.print('EMA alpha:', alpha)
        model_ema = ExponentialMovingAverage(encoder, decay=1.0 - alpha)
    accelerator.print('Finish building Model.\n')

    trainer = DDPMTrainer(opt, encoder, accelerator, model_ema)
    trainer.train(train_datasetloader)