# Author: DmitriiKhizbullin
# Commit 8afb176: "Reorganized files" (2.51 kB)
import datetime
from argparse import ArgumentParser
import torch
from lightning import Trainer
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelSummary
from src.trainer import ViTLightningModule
def _convert_checkpoint(checkpoint_path):
    """Re-save a Lightning training checkpoint through ``save_checkpoint_dk``.

    Loads ``checkpoint_path`` on CPU, prints the checkpoint's top-level keys
    for inspection, restores a ``ViTLightningModule`` from it and writes the
    result to the ``tmp_checkp_path_deleteme`` directory.

    Args:
        checkpoint_path: path to a checkpoint produced by training.
    """
    print("Converting checkpoint", checkpoint_path)
    # NOTE(review): torch.load unpickles the file, which can execute
    # arbitrary code -- only pass checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    print(list(checkpoint.keys()))
    # Free the raw dict before Lightning loads the same file again, so two
    # full copies of the weights are not held in memory at once.
    del checkpoint
    # NOTE(review): hard-coded scratch paths. hparams_file presumably
    # replaces the hparams pickled in the checkpoint -- confirm intent.
    model = ViTLightningModule.load_from_checkpoint(
        checkpoint_path,
        map_location="cpu",
        hparams_file="tmp_ckpt_deleteme.yaml")
    model.save_checkpoint_dk("tmp_checkp_path_deleteme")
    print("Saved checkpoint. Done.")


def _train(tag, debug):
    """Train a freshly constructed ``ViTLightningModule``.

    Logs go to ``./lightning_logs/<timestamp>[_<tag>]`` via TensorBoard.

    Args:
        tag: optional suffix appended to the timestamped log-version name;
            ``None`` means no suffix.
        debug: when true, run Lightning's ``fast_dev_run`` smoke-test cycle
            instead of a real training run.
    """
    print("Start training")
    fast_dev_run = bool(debug)
    model = ViTLightningModule(fast_dev_run)
    datetime_str = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    art_dir_name = datetime_str if tag is None else f"{datetime_str}_{tag}"
    logger = TensorBoardLogger(save_dir=".", name="lightning_logs",
                               version=art_dir_name)
    trainer = Trainer(
        logger=logger,
        benchmark=True,
        devices="auto",
        accelerator="auto",
        # -1 disables the epoch limit (Lightning's infinite-training mode);
        # training runs until stopped externally or by a callback.
        max_epochs=-1,
        callbacks=[
            ModelSummary(max_depth=-1),  # summarise every nested layer
        ],
        fast_dev_run=fast_dev_run,
        log_every_n_steps=10,
    )
    trainer.fit(
        model,
        train_dataloaders=model._train_dataloader,
        val_dataloaders=model._val_dataloader,
    )
    print("Training done")


def main():
    """ Neural network trainer entry point. """
    parser = ArgumentParser(description='KAUST-SDAIA Diabetic Retinopathy')
    parser.add_argument('--tag', action='store', type=str,
                        help='Extra suffix to put on the artefact dir name')
    parser.add_argument('--debug', action='store_true',
                        help="Dummy training cycle for testing purposes")
    parser.add_argument('--convert-checkpoint', action='store', type=str,
                        help='Convert a checkpoint from training to pickle-independent '
                             'predictor-compatible directory')
    args = parser.parse_args()

    # 'high' lets float32 matmuls use faster reduced-precision kernels
    # (e.g. TF32) where the GPU supports them; intended for V100/A100.
    torch.set_float32_matmul_precision('high')

    if args.convert_checkpoint is not None:
        _convert_checkpoint(args.convert_checkpoint)
    else:
        _train(args.tag, args.debug)
if __name__ == "__main__":
    # Launch the CLI only when run as a script, not when imported.
    main()