import sys
import os
from os.path import join as pjoin
from options.train_options import TrainOptions
from utils.plot_script import *
from models import build_models
from utils.ema import ExponentialMovingAverage
from trainers import DDPMTrainer
from motion_loader import get_dataset_loader
from accelerate.utils import set_seed
from accelerate import Accelerator
import torch
import yaml
from box import Box
def yaml_to_box(yaml_file):
    """Load a YAML file into a Box for attribute-style access to its keys."""
    with open(yaml_file, 'r') as file:
        yaml_data = yaml.safe_load(file)
    return Box(yaml_data)
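
# Usage sketch (the key name below is hypothetical, for illustration only):
#   cfg = yaml_to_box('options/edit.yaml')
#   print(cfg.some_flag)   # Box makes this equivalent to cfg['some_flag']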

if __name__ == '__main__':
    accelerator = Accelerator()
    parser = TrainOptions()
    opt = parser.parse(accelerator)
    set_seed(opt.seed)
    # Anomaly detection helps localize NaN/inf gradients but slows training
    # noticeably; it is typically disabled for production runs.
    torch.autograd.set_detect_anomaly(True)

    # Output layout: <checkpoints_dir>/<dataset_name>/<name>/{model,meta}
    opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
    opt.model_dir = pjoin(opt.save_root, 'model')
    opt.meta_dir = pjoin(opt.save_root, 'meta')

    # Select the editing config; 'noedit.yaml' presumably holds the no-op settings.
    if opt.edit_mode:
        edit_config = yaml_to_box('options/edit.yaml')
    else:
        edit_config = yaml_to_box('options/noedit.yaml')

    # Only the main process creates directories, avoiding races across ranks.
    if accelerator.is_main_process:
        os.makedirs(opt.model_dir, exist_ok=True)
        os.makedirs(opt.meta_dir, exist_ok=True)

    train_datasetloader = get_dataset_loader(opt, batch_size=opt.batch_size, split='train', accelerator=accelerator, mode='train')  # 7169

    accelerator.print('\nInitializing model ...')
    encoder = build_models(opt, edit_config=edit_config)
    model_ema = None
    if opt.model_ema:
        # Decay adjustment that aims to keep the decay independent of other
        # hyper-parameters, originally proposed at:
        # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
        adjust = 106_667 * opt.model_ema_steps / opt.num_train_steps
        alpha = 1.0 - opt.model_ema_decay
        alpha = min(1.0, alpha * adjust)
        print('EMA alpha:', alpha)
        model_ema = ExponentialMovingAverage(encoder, decay=1.0 - alpha)
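
    # Worked example of the EMA adjustment (illustrative numbers, not defaults):
    # with model_ema_steps = 10, num_train_steps = 50_000, model_ema_decay = 0.9999:
    #   adjust = 106_667 * 10 / 50_000 ≈ 21.33
    #   alpha  = min(1.0, 0.0001 * 21.33) ≈ 0.00213
    #   decay  = 1.0 - alpha ≈ 0.99787   (the per-update decay actually applied)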

    accelerator.print('Finished building model.\n')
    trainer = DDPMTrainer(opt, encoder, accelerator, model_ema)
    trainer.train(train_datasetloader)
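
# Typical multi-GPU launch via Hugging Face Accelerate (script name and flags
# below are illustrative; the real options are defined in options/train_options.py):
#   accelerate launch train.py --name my_run --dataset_name t2m --batch_size 64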