import sys
import os

# Make the repository root importable when this script is run directly.
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.abspath(os.path.join(THIS_DIR, os.pardir))
sys.path.append(ROOT_DIR)
import torch
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin
from pytorch_lightning.callbacks import ModelCheckpoint
from mpi4py import MPI

from models import create_model
from training.datasets import create_dataset, create_dataloader
from training.options.train_options import TrainOptions
from training.utils import get_latest_checkpoint

print(f"CUDA available: {torch.cuda.is_available()}")

# MPI rank is used so that only one process performs filesystem setup below.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
print(f"MPI rank {rank} of {size}")
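
# Entry point: parse options, build the model, pick a distributed-training
# plugin, construct the dataloaders, then run tune/fit/test as requested.
# Hypothetical example invocation (the actual flags are defined by TrainOptions
# and Lightning's Trainer):
#   python train.py --experiment_name my_exp --batch_size 32 --precision 16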
if __name__ == '__main__':
    pl.seed_everything(69420)
    opt = TrainOptions().parse()
    # Only rank 0 creates the experiment directory, avoiding races between processes.
    if rank == 0:
        os.makedirs(opt.checkpoints_dir + "/" + opt.experiment_name, exist_ok=True)
    print("loaded options")
    print(opt.experiment_name)
    model = create_model(opt)
    print("loaded model")
if "tpu_cores" in vars(opt) and opt.tpu_cores is not None and opt.tpu_cores > 0: | |
plugins = None | |
elif opt.plugins is None: | |
print("DDPPlugin") | |
plugins = DDPPlugin(find_unused_parameters=opt.find_unused_parameters, num_nodes=opt.num_nodes) | |
elif opt.plugins == "deepspeed": | |
deepspeed_config = { | |
"zero_optimization": { | |
"stage": 2, | |
"cpu_offload":False, | |
}, | |
#'train_batch_size': opt.batch_size, | |
'gradient_clipping': opt.gradient_clip_val, | |
'fp16': { | |
'enabled': opt.precision == 16, | |
'loss_scale': 0, | |
'initial_scale_power': 15, | |
}, | |
} | |
plugins = DeepSpeedPlugin(config=deepspeed_config) | |
else: | |
#ddpplugin = DDPPlugin(find_unused_parameters=opt.find_unused_parameters, num_nodes=opt.num_nodes) | |
#plugins = [ddpplugin, opt.plugins] | |
plugins = opt.plugins | |
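
    # Each split is built with create_dataset(...) and initialised via setup()
    # before being wrapped in a dataloader; the val and test splits are only
    # constructed when the corresponding options are enabled.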
    ## Datasets and dataloaders
    train_dataset = create_dataset(opt)
    train_dataset.setup()
    train_dataloader = create_dataloader(train_dataset)
    if opt.do_validation:
        val_dataset = create_dataset(opt, split="val")
        val_dataset.setup()
        val_dataloader = create_dataloader(val_dataset, split="val")
    if opt.do_testing:
        test_dataset = create_dataset(opt, split="test")
        test_dataset.setup()
        test_dataloader = create_dataloader(test_dataset, split="test")
    print('#training sequences = {:d}'.format(len(train_dataset)))
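
    # TensorBoard logs and checkpoints both end up under checkpoints_dir/experiment_name.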
    default_save_path = opt.checkpoints_dir + "/" + opt.experiment_name
    logger = TensorBoardLogger(opt.checkpoints_dir, name=opt.experiment_name, default_hp_metric=False)
    # Keep the five best checkpoints according to the logged 'loss' metric,
    # evaluated every 1000 training steps.
    checkpoint_callback = ModelCheckpoint(
        monitor='loss',
        save_top_k=5,
        every_n_train_steps=1000,
    )
    callbacks = [checkpoint_callback]
    args = Trainer.parse_argparser(opt)
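
    # Resuming: --continue_train restarts from the most recent checkpoint
    # (get_latest_checkpoint presumably scans the experiment's version_*/checkpoints
    # directories for the newest .ckpt file). With --load_weights_only only the
    # model weights are restored, optionally filtered by substring; otherwise
    # optimizer and step state are resumed too via resume_from_checkpoint.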
    if opt.continue_train:
        print("CONTINUE TRAIN")
        # TODO: add an option to override saved hparams with an hparams file when
        # doing continue_train, or even make that the default.
        logs_path = default_save_path
        latest_file = get_latest_checkpoint(logs_path)
        print(latest_file)
        if opt.load_weights_only:
            state_dict = torch.load(latest_file)['state_dict']
            load_strict = True
            if opt.only_load_in_state_dict != "":
                # Keep only parameters whose name contains the given substring.
                state_dict = {k: v for k, v in state_dict.items() if opt.only_load_in_state_dict in k}
                load_strict = False
            if opt.ignore_in_state_dict != "":
                # Drop parameters whose name contains the given substring.
                state_dict = {k: v for k, v in state_dict.items() if opt.ignore_in_state_dict not in k}
                load_strict = False
            model.load_state_dict(state_dict, strict=load_strict)
            trainer = Trainer.from_argparse_args(args, logger=logger, default_root_dir=default_save_path, plugins=plugins, callbacks=callbacks)
        else:
            trainer = Trainer.from_argparse_args(args, logger=logger, default_root_dir=default_save_path, resume_from_checkpoint=latest_file, plugins=plugins, callbacks=callbacks)
    else:
        trainer = Trainer.from_argparse_args(args, logger=logger, default_root_dir=default_save_path, plugins=plugins, callbacks=callbacks)
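
    # Trainer.from_argparse_args forwards Lightning's own flags (gpus, max_epochs,
    # precision, ...) from the parsed options, so they need no explicit wiring here.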
    # Tuning (e.g. Lightning's automatic learning-rate / batch-size search)
    if opt.do_tuning:
        if opt.do_validation:
            trainer.tune(model, train_dataloader, val_dataloader)
        else:
            trainer.tune(model, train_dataloader)
    # Training
    if not opt.skip_training:
        if opt.do_validation:
            trainer.fit(model, train_dataloader, val_dataloader)
        else:
            trainer.fit(model, train_dataloader)
    # Evaluate on the test set with the weights from the latest checkpoint.
    if opt.do_testing:
        print("TESTING")
        latest_file = get_latest_checkpoint(default_save_path)
        print(latest_file)
        state_dict = torch.load(latest_file)
        model.load_state_dict(state_dict['state_dict'])
        trainer.test(model, test_dataloader)