File size: 4,845 Bytes
1964059 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from FrameDatamodule import TyphoonDataModule
from LightningClassifModel import LightningClassifModel
from pytorch_lightning.callbacks import ModelCheckpoint
import config
from argparse import ArgumentParser
from datetime import datetime
import torch
# Timestamp taken once at import time; logged with the hyper parameters of the run
start_time_str = datetime.now().strftime("%Y_%m_%d-%H.%M.%S")
def custom_parse_args(args):
    """Validate and normalize the model_name, size, cropped and device arguments.

    The object is modified in place and also returned.

    Args:
        args: Namespace-like object exposing ``model_name``, ``size``,
            ``cropped`` and ``device`` attributes.

    Returns:
        The same ``args`` object where ``size`` is a ``(side, side)`` tuple,
        the string forms of ``cropped`` are converted to a bool, and ``device``
        is a list of integer indices (or ``config.DEVICES`` when ``device``
        is ``None``).

    Raises:
        ValueError: If at least one argument is invalid. The individual
            problems are printed before the exception is raised.
    """
    args_parsing = ""
    if args.model_name not in ("resnet18", "vgg", "vit"):
        args_parsing += "Please give model_name among resnet18, vgg or vit\n"
    if args.size not in ("512", "224", 512, 224):
        args_parsing += "Please give size equals to 512 or 224\n"
    if args.cropped not in ("False", "True", "false", "true", False, True):
        args_parsing += "Please give cropped equals to False or True\n"

    # Resolve the device before validating it: None means "use the config
    # default" and must never reach int(), and a list/tuple of indices is
    # accepted as-is instead of crashing the conversion.
    device_ids = None
    if args.device is not None:
        if isinstance(args.device, (list, tuple)):
            raw_devices = list(args.device)
        else:
            raw_devices = [args.device]
        try:
            device_ids = [int(device) for device in raw_devices]
        except (TypeError, ValueError):
            device_ids = []
        if not device_ids:
            args_parsing += "Please give device as an integer GPU index\n"
        elif torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            if any(device not in range(gpu_count) for device in device_ids):
                args_parsing += "Please give a device number in the range (0, %d)\n" % gpu_count

    if args_parsing:
        print(args_parsing)
        raise ValueError("Some arguments are not initialized correctly")

    # size was validated above, so it is 512 or 224 (possibly as a string)
    side = int(args.size)
    args.size = (side, side)

    # Only the string spellings need converting; real booleans pass through
    if isinstance(args.cropped, str):
        args.cropped = args.cropped.lower() == "true"

    args.device = config.DEVICES if device_ids is None else device_ids
    return args
def train(hparam):
    """Run a PyTorch Lightning training session.

    The run is configured from the command-line arguments in ``hparam``
    (validated through ``custom_parse_args``) and from the hyper parameters
    of the ``config`` module.

    Args:
        hparam: Parsed arguments exposing ``model_name``, ``size``,
            ``cropped`` and ``device``.

    Returns:
        The string ``"training finished"`` once ``trainer.fit`` has returned.
    """
    hparam = custom_parse_args(hparam)

    # Runs are grouped under <model>_<side>_<cropped|no-crop>
    crop_suffix = "_cropped" if hparam.cropped else "_no-crop"
    run_name = hparam.model_name + "_" + str(hparam.size[0]) + crop_suffix

    tb_logger = TensorBoardLogger(save_dir="results", name=run_name, default_hp_metric=False)

    # Record every hyper parameter alongside the run
    hyper_parameters = {
        'start_time': start_time_str,
        'LEARNING_RATE': config.LEARNING_RATE,
        'BATCH_SIZE': config.BATCH_SIZE,
        'NUM_WORKERS': config.NUM_WORKERS,
        'MAX_EPOCHS': config.MAX_EPOCHS,
        'WEIGHTS': config.WEIGHTS,
        'LABEL': "grade",
        'SPLIT_BY': config.SPLIT_BY,
        'LOAD_DATA': config.LOAD_DATA,
        'DATASET_SPLIT': config.DATASET_SPLIT,
        'STANDARDIZE_RANGE': config.STANDARDIZE_RANGE,
        'DOWNSAMPLE_SIZE': hparam.size,
        'CROPPED': hparam.cropped,
        'NUM_CLASSES': config.NUM_CLASSES,
        'ACCELERATOR': config.ACCELERATOR,
        'DEVICES': hparam.device,
        'DATA_DIR': config.DATA_DIR,
        'MODEL_NAME': hparam.model_name,
    }
    tb_logger.log_hyperparams(hyper_parameters)

    # Data
    data_module = TyphoonDataModule(
        config.DATA_DIR,
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
        split_by=config.SPLIT_BY,
        load_data=config.LOAD_DATA,
        dataset_split=config.DATASET_SPLIT,
        standardize_range=config.STANDARDIZE_RANGE,
        downsample_size=hparam.size,
        cropped=hparam.cropped,
    )

    # Model
    classification_model = LightningClassifModel(
        learning_rate=config.LEARNING_RATE,
        weights=config.WEIGHTS,
        num_classes=config.NUM_CLASSES,
        model_name=hparam.model_name,
    )

    # Checkpoints are written under the logger's own version directory:
    # <save_dir>/<name>/version_<n>/checkpoints/
    checkpoint_dir = "/".join(
        (tb_logger.save_dir, tb_logger.name, "version_%d/checkpoints/" % tb_logger.version)
    )
    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename='model_{epoch}',
        monitor='val_loss',
        verbose=True,
        every_n_epochs=1,
        save_top_k=5,
    )

    # Lightning trainer
    trainer = pl.Trainer(
        logger=tb_logger,
        accelerator=config.ACCELERATOR,
        devices=hparam.device,
        max_epochs=config.MAX_EPOCHS,
        enable_progress_bar=False,
        callbacks=[checkpoint_callback],
    )

    # Start the training session
    trainer.fit(classification_model, data_module)
    return "training finished"
# Script entry point
if __name__ == "__main__":
    # Every option is optional; (flag, default) pairs are registered in order
    cli_options = (
        ("--model_name", 'vgg'),
        ("--size", config.DOWNSAMPLE_SIZE),
        ("--cropped", True),  # if size is 512, cropped argument doesn't have any impact
        ("--device", config.DEVICES),
    )
    parser = ArgumentParser()
    for flag, default_value in cli_options:
        parser.add_argument(flag, default=default_value)

    # Run the training and print its status message
    print(train(parser.parse_args()))
|