File size: 1,393 Bytes
4738a88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
import argparse
import logging
import os
import pickle
from pathlib import Path

import torch
import torch.distributed as dist

from config import MyParser
from steps import trainer
def resolve_run_args(args):
    """Persist fresh run arguments, or restore and merge them when resuming.

    On a fresh run (``args.resume`` falsy) ``args`` is pickled to
    ``<exp_dir>/args.pkl`` and returned unchanged. When ``args.resume`` is
    truthy, the previously pickled arguments are loaded and every attribute
    of ``args`` is layered on top of them. The current command line therefore
    wins on any conflict, and only keys absent from ``args`` keep their saved
    value.

    NOTE(review): if ``MyParser`` fills a default for every option (as
    argparse parsers normally do), a resumed run reuses saved values only for
    options the parser no longer defines. Confirm this is the intended
    resume semantics.

    Args:
        args: parsed arguments; must provide ``exp_dir`` and ``resume``.

    Returns:
        The effective ``argparse.Namespace`` for this run.

    Raises:
        ValueError: if resuming with an empty ``exp_dir``.
    """
    # Path join instead of "%s/args.pkl": with an empty exp_dir the old string
    # form pointed at the filesystem root ("/args.pkl").
    args_path = Path(args.exp_dir) / "args.pkl"
    if not args.resume:
        with open(args_path, "wb") as f:
            pickle.dump(args, f)
        return args

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not args.exp_dir:
        raise ValueError("exp_dir must be set when resuming")
    # args.pkl is expected to come from a previous run of this script.
    # Unpickling can execute arbitrary code, so never resume from an
    # untrusted experiment directory.
    with open(args_path, "rb") as f:
        saved_args = pickle.load(f)
    # Later mapping wins: current CLI values override the saved ones.
    return argparse.Namespace(**{**vars(saved_args), **vars(args)})


def main():
    """Set up logging, resolve run arguments and launch distributed training.

    Intended to run once per GPU process. ``init_method="env://"`` makes
    ``torch.distributed`` read its rendezvous settings from environment
    variables provided by the launcher.
    """
    formatter = (
        "%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d || %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)
    torch.cuda.empty_cache()
    args = MyParser().parse_args()
    logging.info(args)
    exp_dir = Path(args.exp_dir)
    exp_dir.mkdir(exist_ok=True, parents=True)
    logging.info("exp_dir: %s", exp_dir)
    args = resolve_run_args(args)

    dist.init_process_group(backend="nccl", init_method="env://")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # The global rank equals the GPU index only on a single node. On a
    # multi-node run it exceeds the per-node device count. Prefer LOCAL_RANK
    # (set by torchrun) and fall back to the global rank, so single-node
    # launches behave exactly as before.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    torch.cuda.set_device(local_rank)
    # The trainer still receives the global rank, unchanged from before.
    my_trainer = trainer.Trainer(args, world_size, rank)
    my_trainer.train()


if __name__ == "__main__":
    main()