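# Training entry point: parses TrainOptions, builds the dataset/dataloaders and
# model, configures a PyTorch Lightning Trainer (optionally with DDP or
# DeepSpeed), and runs tuning, training, and test evaluation as requested.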
import sys
import os
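# Make the repository root importable so the `training` and `models` packages
# resolve when this script is run directly.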
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.abspath(os.path.join(THIS_DIR, os.pardir))
sys.path.append(ROOT_DIR)
import glob
import torch
print(torch.cuda.is_available())
from training.datasets import create_dataset, create_dataloader
print("HIII")
from models import create_model
import pytorch_lightning as pl
from training.options.train_options import TrainOptions
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
print("HIII")
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin
from pytorch_lightning.callbacks import ModelCheckpoint
from training.utils import get_latest_checkpoint
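# MPI rank/size are queried here; rank 0 alone creates the checkpoint directory
# in the main block below.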
from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
print(rank)
if __name__ == '__main__':
    pl.seed_everything(69420)
    opt = TrainOptions().parse()
    # Path(opt.checkpoints_dir+"/"+opt.experiment_name).mkdir(parents=True,exist_ok=True)
    if rank == 0:
        if not os.path.exists(opt.checkpoints_dir+"/"+opt.experiment_name):
            os.makedirs(opt.checkpoints_dir+"/"+opt.experiment_name)
    print("loaded options")
    print(opt.experiment_name)
    model = create_model(opt)
    print("loaded model")
if "tpu_cores" in vars(opt) and opt.tpu_cores is not None and opt.tpu_cores > 0:
plugins = None
elif opt.plugins is None:
print("DDPPlugin")
plugins = DDPPlugin(find_unused_parameters=opt.find_unused_parameters, num_nodes=opt.num_nodes)
elif opt.plugins == "deepspeed":
deepspeed_config = {
"zero_optimization": {
"stage": 2,
"cpu_offload":False,
},
#'train_batch_size': opt.batch_size,
'gradient_clipping': opt.gradient_clip_val,
'fp16': {
'enabled': opt.precision == 16,
'loss_scale': 0,
'initial_scale_power': 15,
},
}
plugins = DeepSpeedPlugin(config=deepspeed_config)
else:
#ddpplugin = DDPPlugin(find_unused_parameters=opt.find_unused_parameters, num_nodes=opt.num_nodes)
#plugins = [ddpplugin, opt.plugins]
plugins = opt.plugins
    ## Datasets and dataloaders
    train_dataset = create_dataset(opt)
    train_dataset.setup()
    train_dataloader = create_dataloader(train_dataset)
    if opt.do_validation:
        val_dataset = create_dataset(opt, split="val")
        val_dataset.setup()
        val_dataloader = create_dataloader(val_dataset, split="val")
    if opt.do_testing:
        test_dataset = create_dataset(opt, split="test")
        test_dataset.setup()
        test_dataloader = create_dataloader(test_dataset, split="test")
    print('#training sequences = {:d}'.format(len(train_dataset)))
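    # Logging and checkpointing: TensorBoard logs under the experiment directory
    # and periodic checkpoints ranked by the monitored training loss.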
    default_save_path = opt.checkpoints_dir+"/"+opt.experiment_name
    logger = TensorBoardLogger(opt.checkpoints_dir, name=opt.experiment_name, default_hp_metric=False)
    checkpoint_callback = ModelCheckpoint(
        monitor='loss',
        save_top_k=5,
        every_n_train_steps=1000,
        # every_n_train_steps=10,
    )
    callbacks = [checkpoint_callback]
    args = Trainer.parse_argparser(opt)
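    # Resuming: when continue_train is set, either load only the weights from the
    # latest checkpoint (optionally filtered by substring) or resume the full
    # Trainer state from that checkpoint.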
    if opt.continue_train:
        print("CONTINUE TRAIN")
        # TODO: add option to override saved hparams when doing continue_train with an hparams file, or even make that default
        logs_path = default_save_path
        latest_file = get_latest_checkpoint(logs_path)
        print(latest_file)
        if opt.load_weights_only:
            state_dict = torch.load(latest_file)
            state_dict = state_dict['state_dict']
            load_strict = True
            if opt.only_load_in_state_dict != "":
                state_dict = {k: v for k, v in state_dict.items() if (opt.only_load_in_state_dict in k)}
                load_strict = False
            if opt.ignore_in_state_dict != "":
                state_dict = {k: v for k, v in state_dict.items() if not (opt.ignore_in_state_dict in k)}
                load_strict = False
            model.load_state_dict(state_dict, strict=load_strict)
            trainer = Trainer.from_argparse_args(args, logger=logger, default_root_dir=default_save_path, plugins=plugins, callbacks=callbacks)
        else:
            trainer = Trainer.from_argparse_args(args, logger=logger, default_root_dir=default_save_path, resume_from_checkpoint=latest_file, plugins=plugins, callbacks=callbacks)
    else:
        trainer = Trainer.from_argparse_args(args, logger=logger, default_root_dir=default_save_path, plugins=plugins, callbacks=callbacks)
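    # Run the requested phases in order: Trainer.tune, training, and finally
    # evaluation of the latest checkpoint on the test set.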
    # Tuning
    if opt.do_tuning:
        if opt.do_validation:
            trainer.tune(model, train_dataloader, val_dataloader)
        else:
            trainer.tune(model, train_dataloader)
    # Training
    if not opt.skip_training:
        if opt.do_validation:
            trainer.fit(model, train_dataloader, val_dataloader)
        else:
            trainer.fit(model, train_dataloader)
    # Evaluate on the test set
    if opt.do_testing:
        print("TESTING")
        logs_path = default_save_path
        latest_file = get_latest_checkpoint(logs_path)
        print(latest_file)
        state_dict = torch.load(latest_file)
        model.load_state_dict(state_dict['state_dict'])
        trainer.test(model, test_dataloader)
    # trainer = Trainer(logger=logger)
    # # trainer.test(model, train_dataloader)
    # logs_path = default_save_path
    # checkpoint_subdirs = [(d,int(d.split("_")[1])) for d in os.listdir(logs_path) if os.path.isdir(logs_path+"/"+d)]
    # checkpoint_subdirs = sorted(checkpoint_subdirs,key=lambda t: t[1])
    # checkpoint_path=logs_path+"/"+checkpoint_subdirs[-1][0]+"/checkpoints/"
    # list_of_files = glob.glob(checkpoint_path+'/*')  # * means all; use e.g. *.csv if a specific format is needed
    # latest_file = max(list_of_files, key=os.path.getctime)
    # print(latest_file)
    # trainer.test(model, test_dataloaders=test_dataloader, ckpt_path=latest_file)
    # trainer.test(test_dataloaders=test_dataloader, ckpt_path=latest_file)
    # trainer.test(test_dataloaders=test_dataloader)