# TEDM-demo / config.py
import argparse
import os
from datetime import datetime
from pathlib import Path
import torch
this_dir = os.path.dirname(os.path.realpath(__file__))
default_logdir = os.path.join(this_dir, 'logs', datetime.now().strftime('%Y%m%d_%H%M%S'))
parser = argparse.ArgumentParser()
parser.add_argument('--debug', action='store_true')
parser.add_argument('--mixed_precision', action='store_true', help='Use mixed precision')
parser.add_argument('--resume_path', type=str, default=None, help='Path to checkpoint to resume from')
# Experiment parameters
parser.add_argument('--experiment', type=str, default="img_only", choices=[
    "PDDM",
    "baseline",
    "LEDM",
    "LEDMe",
    "TEDM",
    "global_cl",
    "local_cl",
    "global_finetune",
    "glob_loc_finetune"
], help='Which experiment / training setup to run')
parser.add_argument('--dataset', type=str, default="JSRT", choices=["JSRT", "CXR14"], help='Dataset to use')
# Data parameters
parser.add_argument('--img_size', type=int, default=128, help='Height / width of the input image to the network')
parser.add_argument('--data_dir', type=str, help='Path to the dataset')
parser.add_argument('--num_workers', type=int, default=4, help='Number of subprocesses to use for data loading')
# Model parameters
parser.add_argument('--dim', type=int, default=64, help='Width of the U-Net')
parser.add_argument('--dim_mults', nargs='+', type=int, default=(1, 2, 4, 8), help='Dimension multipliers for U-Net levels')
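# For example, --dim 64 with --dim_mults 1 2 4 8 would give U-Net levels with
# channel widths 64, 128, 256 and 512 (an illustration of how these two flags
# are typically combined, not a statement about this repository's exact model).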
# SegDiff model parameters
parser.add_argument('--seg_out_dim', type=int, default=1, help='Dimension of segmentation embedding')
parser.add_argument('--img_out_dim', type=int, default=4, help='Dimension of image embedding')
parser.add_argument('--img_inter_dim', type=int, default=32, help='Width of image embedding')
# Diffusion parameters
parser.add_argument('--timesteps', type=int, default=1000, help='Number of diffusion timesteps')
parser.add_argument('--beta_schedule', type=str, default='cosine', choices=['linear', 'cosine'])
parser.add_argument('--objective', type=str, default='pred_noise', help='Model output', choices=['pred_noise', 'pred_x_0'])
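
# Illustrative sketch (not part of the original training code): how the
# 'linear' and 'cosine' beta schedules selected above are commonly built,
# following Nichol & Dhariwal (https://arxiv.org/abs/2102.09672). The function
# name and constants here are assumptions for illustration only.
def _example_beta_schedule(schedule: str, timesteps: int) -> torch.Tensor:
    if schedule == 'linear':
        # Linearly spaced betas, scaled for the number of timesteps.
        scale = 1000 / timesteps
        return torch.linspace(scale * 1e-4, scale * 0.02, timesteps)
    # Cosine schedule: betas derived from a squared-cosine alpha_bar curve.
    s = 0.008
    t = torch.linspace(0, timesteps, timesteps + 1) / timesteps
    alphas_bar = torch.cos((t + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_bar = alphas_bar / alphas_bar[0]
    betas = 1 - (alphas_bar[1:] / alphas_bar[:-1])
    return betas.clamp(0, 0.999)
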
# CL parameters
parser.add_argument('--tau', type=float, default=0.1, help='Temperature parameter for contrastive loss')
parser.add_argument('--global_model_path', type=str, default=None, help='Path to global model checkpoint')
parser.add_argument('--glob_loc_model_path', type=str, default=None, help='Path to global & local CL model checkpoint')
parser.add_argument('--unfreeze_weights_at_step', type=int, default=0, help='Step at which to unfreeze pretrained weights. If 0, weights are not frozen')
parser.add_argument('--augment_at_finetuning', default=False, action='store_true', help='Whether to augment images during finetuning')
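
# Illustrative sketch (an assumption, not the repository's loss implementation):
# how a temperature-scaled contrastive (InfoNCE-style) loss typically uses
# --tau. z1 and z2 are L2-normalised embeddings of two views of the same batch.
def _example_info_nce(z1: torch.Tensor, z2: torch.Tensor, tau: float) -> torch.Tensor:
    logits = z1 @ z2.t() / tau                                 # pairwise similarities scaled by temperature
    targets = torch.arange(z1.size(0), device=z1.device)       # positives lie on the diagonal
    return torch.nn.functional.cross_entropy(logits, targets)
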
# Training parameters
parser.add_argument('--batch_size', type=int, default=16, help='Input batch size')
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
parser.add_argument('--weight_decay', type=float, default=0, help='Weight decay')
# parser.add_argument('--adam_betas', nargs=2, type=float, default=(0.9, 0.99), help='Betas for the Adam optimizer')
parser.add_argument('--max_steps', type=int, default=500000, help='Number of training steps to perform')
parser.add_argument('--p2_loss_weight_gamma', type=float, default=0., help='P2 loss weighting gamma from https://arxiv.org/abs/2204.00227; 0 gives a uniform weight of 1 across timesteps, 1.0 is the recommended value')
parser.add_argument('--p2_loss_weight_k', type=float, default=1., help='P2 loss weighting k (see --p2_loss_weight_gamma)')
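
# Illustrative sketch (an assumption about how these two flags are consumed):
# the p2 weighting from https://arxiv.org/abs/2204.00227 is usually computed
# per timestep as (k + SNR(t)) ** (-gamma), where SNR(t) = alpha_bar_t / (1 - alpha_bar_t).
# With gamma = 0 this reduces to a constant weight of 1 for every timestep.
def _example_p2_loss_weight(alphas_bar: torch.Tensor, k: float, gamma: float) -> torch.Tensor:
    snr = alphas_bar / (1 - alphas_bar)
    return (k + snr) ** (-gamma)
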
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to use')
parser.add_argument('--seed', type=int, default=0, help='Random seed')
# Logging parameters
parser.add_argument('--log_freq', type=int, default=100, help='Frequency of logging')
parser.add_argument('--val_freq', type=int, default=100, help='Frequency of validation')
parser.add_argument('--val_steps', type=int, default=250, help='Number of timesteps to use for validation')
parser.add_argument('--log_dir', type=str, default=default_logdir, help='Logging directory')
parser.add_argument('--n_sampled_imgs', type=int, default=8, help='Number of images to sample during logging')
parser.add_argument('--max_val_steps', type=int, default=-1, help='Number of validation steps to perform')
# datasetGAN like segmentation model parameters
parser.add_argument("--saved_diffusion_model", type=str, help='Path to checkpoint of trained diffusion model', default="logs/20230127_164150/best_model.pt")
parser.add_argument("--t_steps_to_save", type=int, nargs='*', choices=range(1000), help='Diffusion steps to be used as features', default=[50, 200, 400, 600, 800])
parser.add_argument("--n_labelled_images", type=int, help='Number of labelled images to use for semi-supervised training', default=None,
choices=[197, 98, 49, 24, 12, 6, 3, 1])
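
# Illustrative sketch (an assumption modelled on DatasetGAN/DatasetDM-style
# pipelines): features for the segmentation head are typically gathered by
# running the trained diffusion U-Net at each t in --t_steps_to_save and
# concatenating the resulting activations per pixel. `extract_fn` is a
# hypothetical callable standing in for that feature extraction step.
def _example_collect_features(extract_fn, image: torch.Tensor, t_steps: list) -> torch.Tensor:
    # extract_fn(image, t) is assumed to return a (C, H, W) feature map for
    # timestep t; the maps are stacked along the channel dimension.
    return torch.cat([extract_fn(image, t) for t in t_steps], dim=0)
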
# other experiments I played with
parser.add_argument("--shared_weights_over_timesteps", help='In datasetDM, only use last timestep to predict, and intermediate timesteps to train', default=False, action='store_true')
parser.add_argument("--early_stop", help='In baseline, if validation loss increases by more than 50%, stop', default=False, action='store_true')