|
import os |
|
import sys |
|
import copy |
|
import argparse |
|
|
|
import torch |
|
from torch import optim |
|
import torch.nn as nn |
|
|
|
import mlflow.pytorch |
|
from torch.utils.data import DataLoader |
|
from torchvision.models import resnet18 |
|
import torchvision.transforms as T |
|
from pytorch_lightning.metrics.functional import accuracy |
|
import pytorch_lightning as pl |
|
from pytorch_lightning.callbacks import ModelCheckpoint |
|
|
|
from utils.base import AuxLoss, WeightedLoss, display_mlflow_run_info, l2_regularization, str2bool, fetch_from_mlflow, get_name, data_loader_mean_and_std |
|
from utils.dataset_utils import k_fold |
|
from utils.augmentation import get_augmentation |
|
from dataset import Subset, get_dataset |
|
|
|
from processing.pipeline_numpy import RawProcessingPipeline |
|
from processing.pipeline_torch import append_additive_layer, raw2rgb, RawToRGB, ParametrizedProcessing, NNProcessing |
|
|
|
from model import log_tensor, resnet_model, LitModel, TrackImagesCallback |
|
|
|
import segmentation_models_pytorch as smp |
|
|
|
from utils.ssim import SSIM |
|
|
|
|
|
# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='classification_task')

# MLflow server / artifact URIs
parser.add_argument('--tracking_uri', type=str,
                    default='http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com', help='URI of the mlflow server on AWS')
parser.add_argument('--processor_uri', type=str, default=None,
                    help='URI of the processing model (e.g. s3://mlflow-artifacts-821771080529/1/5fa754c566e3466690b1d309a476340f/artifacts/processing-model)')
parser.add_argument('--classifier_uri', type=str, default=None,
                    help='URI of the net (e.g. s3://mlflow-artifacts-821771080529/1/5fa754c566e3466690b1d309a476340f/artifacts/prediction-model)')
parser.add_argument('--state_dict_uri', type=str,
                    default=None, help='URI of the indices you want to load (e.g. s3://mlflow-artifacts-601883093460/7/4326da05aca54107be8c554de0674a14/artifacts/training)')

# Experiment / run naming
parser.add_argument('--experiment_name', type=str,
                    default='classification learnable pipeline', help='Specify the experiment you are running, e.g. end2end segmentation')
parser.add_argument('--run_name', type=str,
                    default='test run', help='Specify the name of your run')

# Model persistence
parser.add_argument('--log_model', type=str2bool, default=True, help='Enables model logging')
parser.add_argument('--save_locally', action='store_true',
                    help='Model will be saved locally if action is taken')

# Image/tensor tracking during training
parser.add_argument('--track_processing', action='store_true',
                    help='Save images after each transformation of the pipeline for the test set')
parser.add_argument('--track_processing_gradients', action='store_true',
                    help='Save images of gradients after each transformation of the pipeline for the test set')
parser.add_argument('--track_save_tensors', action='store_true',
                    help='Save the torch tensors after each transformation of the pipeline for the test set')
parser.add_argument('--track_predictions', action='store_true',
                    help='Save images after each transformation of the pipeline for the test set + input gradient')
# BUG FIX: type=int was missing; a CLI-supplied value arrived as str and
# crashed range(args.track_n_images) in run_train.
parser.add_argument('--track_n_images', type=int, default=5,
                    help='Track the n first elements of dataset. Only used for args.track_processing=True')
parser.add_argument('--track_every_epoch', action='store_true', help='Track images every epoch or once after training')

# Dataset / cross-validation
parser.add_argument('--seed', type=int, default=1, help='Global seed')
parser.add_argument('--dataset', type=str, default='Microscopy',
                    choices=['Drone', 'DroneSegmentation', 'Microscopy'], help='Select dataset')
parser.add_argument('--n_splits', type=int, default=1, help='Number of splits used for training')
parser.add_argument('--train_size', type=float, default=0.8, help='Fraction of training points in dataset')

# Optimization hyperparameters
parser.add_argument('--lr', type=float, default=1e-5, help='learning rate used for training')
parser.add_argument('--epochs', type=int, default=3, help='number of epochs')
parser.add_argument('--batch_size', type=int, default=32, help='Training batch size')
parser.add_argument('--augmentation', type=str, default='none',
                    choices=['none', 'weak', 'strong'], help='Applies augmentation to training')
parser.add_argument('--check_val_every_n_epoch', type=int, default=1)

# RAW-to-RGB processing
parser.add_argument('--processing_mode', type=str, default='parametrized',
                    choices=['parametrized', 'static', 'neural_network', 'none'],
                    help='Which type of raw to rgb processing should be used')

# Classifier / segmentation network
parser.add_argument('--classifier_network', type=str, default='ResNet18', choices=['ResNet18', 'ResNet34', 'Resnet50'],
                    help='Type of pretrained network')
parser.add_argument('--classifier_pretrained', action='store_true',
                    help='Whether to use a pre-trained model or not')
parser.add_argument('--smp_encoder', type=str, default='resnet34', help='segmentation models pytorch encoder')

parser.add_argument('--freeze_processor', action='store_true', help='Freeze raw to rgb processing model weights')
parser.add_argument('--freeze_classifier', action='store_true', help='Freeze classification model weights')

# Static (numpy) pipeline algorithm choices
parser.add_argument('--sp_debayer', type=str, default='bilinear',
                    choices=['bilinear', 'malvar2004', 'menon2007'], help='Specify algorithm used as debayer')
parser.add_argument('--sp_sharpening', type=str, default='sharpening_filter',
                    choices=['sharpening_filter', 'unsharp_masking'], help='Specify algorithm used for sharpening')
parser.add_argument('--sp_denoising', type=str, default='gaussian_denoising',
                    choices=['gaussian_denoising', 'median_denoising', 'fft_denoising'], help='Specify algorithm used for denoising')

# Adversarial training of the processor against a frozen classifier
parser.add_argument('--adv_training', action='store_true', help='Enable adversarial training')
parser.add_argument('--adv_aux_weight', type=float, default=1, help='Weighting of the adversarial auxiliary loss')
parser.add_argument('--adv_aux_loss', type=str, default='ssim', choices=['l2', 'ssim'],
                    help='Type of adversarial auxiliary regularization loss')
parser.add_argument('--adv_noise_layer', action='store_true', help='Adds an additive layer to Parametrized Processing')
parser.add_argument('--adv_track_differences', action='store_true', help='Save difference to default pipeline')
parser.add_argument('--adv_parameters', choices=['all', 'black_level', 'white_balance',
                                                 'colour_correction', 'gamma_correct', 'sharpening_filter', 'gaussian_blur', 'additive_layer'],
                    help='Target individual parameters for adversarial training.')

parser.add_argument('--cache_downloaded_models', type=str2bool, default=True)

parser.add_argument('--test_run', action='store_true')

args = parser.parse_args()

# Local scratch directory for artifacts that are logged to MLflow afterwards.
os.makedirs('results', exist_ok=True)
|
|
|
|
|
def run_train(args):
    """Train a classifier/segmenter together with a RAW-to-RGB processing pipeline.

    Performs ``args.n_splits``-fold cross-validation on the selected dataset;
    each fold trains inside a nested MLflow run under one parent run. Depending
    on ``args.processing_mode`` the raw images go through a static (frozen numpy)
    pipeline, a parametrized differentiable pipeline, a neural network, or a
    plain demosaic-to-RGB step. With ``--adv_training`` the processor is trained
    to *maximise* the frozen classifier's loss while an auxiliary loss keeps its
    output close to a frozen copy of the default pipeline.

    Args:
        args: parsed command-line arguments (see the argparse setup above).

    Returns:
        The ``LitModel`` of the last trained fold.
    """
    print(args)

    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    training_mode = 'adversarial' if args.adv_training else 'default'

    mlflow.set_tracking_uri(args.tracking_uri)
    mlflow.set_experiment(args.experiment_name)
    # NOTE(review): credentials belong in the environment/AWS config, not in code.
    os.environ['AWS_ACCESS_KEY_ID'] = '#TODO: fill in your aws access key id for mlflow server here'
    os.environ['AWS_SECRET_ACCESS_KEY'] = '#TODO: fill in your aws secret access key for mlflow server here'

    dataset = get_dataset(args.dataset)

    print(f'dataset: {type(dataset).__name__}[{len(dataset)}]')
    print(f'task: {dataset.task}')
    print(f'mode: {training_mode} training')
    print(f'# cross-validation subsets: {args.n_splits}')
    pl.seed_everything(args.seed)
    idxs_kfold = k_fold(dataset, n_splits=args.n_splits, seed=args.seed, train_size=args.train_size)

    with mlflow.start_run(run_name=args.run_name) as parent_run:

        for k_iter, (train_indices, valid_indices) in enumerate(idxs_kfold):

            print(f'K_fold subset: {k_iter+1}/{args.n_splits}')

            # --- processor selection -------------------------------------------------
            if args.processing_mode == 'static':
                # Non-learnable numpy pipeline applied as a dataset transform;
                # normalization uses per-dataset RGB statistics.
                if args.dataset == 'Drone' or args.dataset == 'DroneSegmentation':
                    mean = torch.tensor([0.35, 0.36, 0.35])
                    std = torch.tensor([0.12, 0.11, 0.12])
                elif args.dataset == 'Microscopy':
                    mean = torch.tensor([0.91, 0.84, 0.94])
                    std = torch.tensor([0.08, 0.12, 0.05])

                dataset.transform = T.Compose([RawProcessingPipeline(
                    camera_parameters=dataset.camera_parameters,
                    debayer=args.sp_debayer,
                    sharpening=args.sp_sharpening,
                    denoising=args.sp_denoising,
                ),
                    T.Normalize(mean, std)
                ])

                # Processing already happens in the transform, so the torch-side
                # processor is a no-op.
                processor = nn.Identity()

            if args.processor_uri is not None and args.processing_mode != 'none':
                print('Fetching processor: ', end='')
                processor = fetch_from_mlflow(args.processor_uri, type='processor',
                                              use_cache=args.cache_downloaded_models)
            else:
                print(f'processing_mode: {args.processing_mode}')
                normalize_mosaic = None

                if args.dataset == 'Microscopy':
                    mosaic_mean = [0.5663, 0.1401, 0.0731]
                    mosaic_std = [0.097, 0.0423, 0.008]
                    normalize_mosaic = T.Normalize(mosaic_mean, mosaic_std)

                track_stages = args.track_processing or args.track_processing_gradients
                if args.processing_mode == 'parametrized':
                    processor = ParametrizedProcessing(
                        camera_parameters=dataset.camera_parameters, track_stages=track_stages, batch_norm_output=True)

                elif args.processing_mode == 'neural_network':
                    processor = NNProcessing(track_stages=track_stages,
                                             normalize_mosaic=normalize_mosaic, batch_norm_output=True)
                elif args.processing_mode == 'none':
                    processor = RawToRGB(reduce_size=True, out_channels=3, track_stages=track_stages,
                                         normalize_mosaic=normalize_mosaic)

            # --- classifier selection ------------------------------------------------
            if args.classifier_uri:
                print('Fetching classifier: ', end='')
                classifier = fetch_from_mlflow(args.classifier_uri, type='classifier',
                                               use_cache=args.cache_downloaded_models)
            else:
                if dataset.task == 'classification':
                    classifier = resnet_model(
                        model=args.classifier_network,
                        pretrained=args.classifier_pretrained,
                        in_channels=3,
                        fc_out_features=len(dataset.classes)
                    )
                else:
                    classifier = smp.UnetPlusPlus(
                        encoder_name=args.smp_encoder,
                        encoder_depth=5,
                        encoder_weights='imagenet',
                        in_channels=3,
                        classes=1,
                        activation=None,
                    )

            if args.freeze_processor and len(list(iter(processor.parameters()))) == 0:
                print('Note: freezing processor without parameters.')
            assert not (args.freeze_processor and args.freeze_classifier), 'Likely no parameters to train.'

            # --- losses and metrics --------------------------------------------------
            if dataset.task == 'classification':
                loss = nn.CrossEntropyLoss()
                metrics = [accuracy]
            else:
                loss = smp.losses.DiceLoss(mode='binary', from_logits=True)
                metrics = [smp.utils.metrics.IoU()]

            loss_aux = None

            if args.adv_training:
                assert args.processing_mode == 'parametrized', f"Processing mode ({args.processing_mode}) should be set to 'parametrized' for adversarial training"
                assert args.freeze_classifier, 'Classifier should be frozen for adversarial training'
                assert not args.freeze_processor, 'Processor should not be frozen for adversarial training'

                # Frozen copy of the unperturbed pipeline; serves as the reference
                # for the auxiliary similarity loss (and optional image diffs).
                processor_default = copy.deepcopy(processor)
                processor_default.track_stages = args.track_processing
                processor_default.eval()
                processor_default.to(DEVICE)

                for p in processor_default.parameters():
                    p.requires_grad = False

                if args.adv_noise_layer:
                    append_additive_layer(processor)

                if args.adv_aux_loss == 'l2':
                    regularization = l2_regularization
                elif args.adv_aux_loss == 'ssim':
                    regularization = SSIM(window_size=11)
                else:
                    # BUG FIX: the exception was previously constructed but never
                    # raised, leaving `regularization` undefined.
                    raise NotImplementedError(args.adv_aux_loss)

                # weight=-1 flips the sign so the processor maximises the task loss.
                loss = WeightedLoss(loss=loss, weight=-1)

                loss_aux = AuxLoss(
                    loss_aux=regularization,
                    processor_adv=processor,
                    processor_default=processor_default,
                    weight=args.adv_aux_weight,
                )

            augmentation = get_augmentation(args.augmentation)

            model = LitModel(
                classifier=classifier,
                processor=processor,
                loss=loss,
                lr=args.lr,
                loss_aux=loss_aux,
                adv_training=args.adv_training,
                adv_parameters=args.adv_parameters,
                metrics=metrics,
                augmentation=augmentation,
                is_segmentation_task=dataset.task == 'segmentation',
                freeze_classifier=args.freeze_classifier,
                freeze_processor=args.freeze_processor,
            )

            # --- run bookkeeping -----------------------------------------------------
            state_dict = vars(args).copy()

            if args.state_dict_uri:
                # Resume the exact train/valid split of a previous run.
                state_dict = mlflow.pytorch.load_state_dict(args.state_dict_uri)
                train_indices = state_dict['train_indices']
                valid_indices = state_dict['valid_indices']

            track_indices = list(range(args.track_n_images))

            if dataset.task == 'classification':
                state_dict['classes'] = dataset.classes
            state_dict['device'] = DEVICE
            state_dict['train_indices'] = train_indices
            state_dict['valid_indices'] = valid_indices
            state_dict['elements in train set'] = len(train_indices)
            state_dict['elements in test set'] = len(valid_indices)

            if args.test_run:
                # Smoke test: shrink both splits to a single batch.
                train_indices = train_indices[:args.batch_size]
                valid_indices = valid_indices[:args.batch_size]

            train_set = Subset(dataset, indices=train_indices)
            valid_set = Subset(dataset, indices=valid_indices)
            track_set = Subset(dataset, indices=track_indices)

            train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=16, shuffle=True)
            valid_loader = DataLoader(valid_set, batch_size=args.batch_size, num_workers=16, shuffle=False)
            track_loader = DataLoader(track_set, batch_size=args.batch_size, num_workers=16, shuffle=False)

            with mlflow.start_run(run_name=f"{args.run_name}_{k_iter}", nested=True) as child_run:

                if k_iter == 0:
                    display_mlflow_run_info(child_run)

                mlflow.pytorch.log_state_dict(state_dict, artifact_path=None)

                hparams = {
                    'dataset': args.dataset,
                    'processing_mode': args.processing_mode,
                    'training_mode': training_mode,
                }
                if training_mode == 'adversarial':
                    hparams['adv_aux_weight'] = args.adv_aux_weight
                    hparams['adv_aux_loss'] = args.adv_aux_loss

                mlflow.log_params(hparams)

                # Log the full invocation + state dict as a plain-text artifact.
                with open('results/state_dict.txt', 'w') as f:
                    f.write('python ' + ' '.join(sys.argv) + '\n')
                    f.write('\n'.join([f'{k}={v}' for k, v in state_dict.items()]))
                mlflow.log_artifact('results/state_dict.txt', artifact_path=None)

                mlf_logger = pl.loggers.MLFlowLogger(experiment_name=args.experiment_name,
                                                     tracking_uri=args.tracking_uri,)
                # Attach the Lightning logger to the already-open child run
                # instead of letting it create a fresh one.
                mlf_logger._run_id = child_run.info.run_id

                reference_processor = processor_default if args.adv_training and args.adv_track_differences else None

                callbacks = []
                if args.track_processing:
                    callbacks += [TrackImagesCallback(track_loader,
                                                      reference_processor,
                                                      track_every_epoch=args.track_every_epoch,
                                                      track_processing=args.track_processing,
                                                      track_gradients=args.track_processing_gradients,
                                                      track_predictions=args.track_predictions,
                                                      save_tensors=args.track_save_tensors)]

                trainer = pl.Trainer(
                    gpus=1 if DEVICE == 'cuda' else 0,
                    min_epochs=args.epochs,
                    max_epochs=args.epochs,
                    logger=mlf_logger,
                    callbacks=callbacks,
                    check_val_every_n_epoch=args.check_val_every_n_epoch,
                )

                if args.log_model:
                    mlflow.pytorch.autolog(log_every_n_epoch=10)
                    print(f'model_uri="{mlflow.get_artifact_uri()}/model"')

                # trainer.fit returns None in this PL version; no need to keep it.
                trainer.fit(
                    model,
                    train_dataloader=train_loader,
                    val_dataloaders=valid_loader,
                )

    # Deliberate debugging hack: expose all locals at module level for
    # interactive post-mortem inspection.
    globals().update(locals())

    return model
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: `args` was parsed at module level when the file loaded.
    model = run_train(args)
|
|