# Spaces:
# Runtime error
# Runtime error
import argparse
import os
from datetime import datetime
from pathlib import Path

import torch
# ---------------------------------------------------------------------------
# Command-line configuration.
#
# Defines the module-level ``parser`` together with the two path constants
# ``this_dir`` and ``default_logdir``. Nothing in this section calls
# ``parser.parse_args()``; it only declares the available options.
# ---------------------------------------------------------------------------

# Directory containing this file (symlinks resolved).
this_dir = os.path.dirname(os.path.realpath(__file__))
# Default log directory: <this_dir>/logs/<YYYYmmdd_HHMMSS>, stamped once at import time.
default_logdir = os.path.join(this_dir, 'logs', datetime.now().strftime('%Y%m%d_%H%M%S'))


def _str2bool(value: str) -> bool:
    """Convert a command-line token into a bool.

    ``type=bool`` must not be used with argparse: it calls ``bool()`` on the
    raw string, so every non-empty value -- including ``"False"`` -- is True.

    Args:
        value: The raw token from the command line (case-insensitive).

    Returns:
        True for ``true/t/yes/y/on/1``, False for ``false/f/no/n/off/0``.

    Raises:
        argparse.ArgumentTypeError: If the token is not one of the above.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', 'off', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()
parser.add_argument('--debug', action='store_true')
# Takes an explicit value (e.g. ``--mixed_precision True``); parsed by _str2bool
# so that ``--mixed_precision False`` really yields False.
parser.add_argument('--mixed_precision', type=_str2bool, default=False, help='Use mixed precision')
parser.add_argument('--resume_path', type=str, default=None, help='Path to checkpoint to resume from')

# Experiment parameters
# NOTE(review): the default "img_only" is not one of the listed choices (argparse
# does not validate defaults against ``choices``), and the help text does not
# describe the choices below. Left unchanged here -- confirm the intended
# default and wording with the code that consumes ``args.experiment``.
parser.add_argument('--experiment', type=str, default="img_only", choices=[
    "PDDM",
    "baseline",
    "LEDM",
    "LEDMe",
    "TEDM",
    "global_cl",
    "local_cl",
    "global_finetune",
    "glob_loc_finetune",
], help='Whether to generate only images or images and segmentations')
parser.add_argument('--dataset', type=str, default="JSRT", choices=["JSRT", "CXR14"], help='Dataset to use')

# Data parameters
parser.add_argument('--img_size', type=int, default=128, help='Height / width of the input image to the network')
parser.add_argument('--data_dir', type=str, help='Path to the dataset')
parser.add_argument('--num_workers', type=int, default=4, help='Number of subprocesses to use for data loading')

# Model parameters
parser.add_argument('--dim', type=int, default=64, help='Width of the U-Net')
# The default is a tuple, but values given on the command line arrive as a list.
parser.add_argument('--dim_mults', nargs='+', type=int, default=(1, 2, 4, 8), help='Dimension multipliers for U-Net levels')

# SegDiff model parameters
parser.add_argument('--seg_out_dim', type=int, default=1, help='Dimension of segmentation embedding')
parser.add_argument('--img_out_dim', type=int, default=4, help='Dimension of image embedding')
parser.add_argument('--img_inter_dim', type=int, default=32, help='Width of image embedding')

# Diffusion parameters
parser.add_argument('--timesteps', type=int, default=1000, help='Number of diffusion timesteps')
parser.add_argument('--beta_schedule', type=str, default='cosine', choices=['linear', 'cosine'])
parser.add_argument('--objective', type=str, default='pred_noise', help='Model output', choices=['pred_noise', 'pred_x_0'])

# CL parameters
parser.add_argument('--tau', type=float, default=0.1, help='Temperature parameter for contrastive loss')
parser.add_argument('--global_model_path', type=str, default=None, help='Path to global model checkpoint')
parser.add_argument('--glob_loc_model_path', type=str, default=None, help='Path to global & local CL model checkpoint')
parser.add_argument('--unfreeze_weights_at_step', type=int, default=0, help='Step at which to unfreeze pretrained weights. If 0, weights are not frozen')
parser.add_argument('--augment_at_finetuning', default=False, action='store_true', help='Whether to augment images during finetuning')

# Training parameters
parser.add_argument('--batch_size', type=int, default=16, help='Input batch size')
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
parser.add_argument('--weight_decay', type=float, default=0, help='Weight decay')
# parser.add_argument('--adam_betas', nargs=2, type=float, default=(0.9, 0.99), help='Betas for the Adam optimizer')
parser.add_argument('--max_steps', type=int, default=500000, help='Number of training steps to perform')
parser.add_argument('--p2_loss_weight_gamma', type=float, default=0., help='p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended')
parser.add_argument('--p2_loss_weight_k', type=float, default=1.)
# Defaults to the GPU when torch reports one is available, otherwise the CPU.
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to use')
parser.add_argument('--seed', type=int, default=0, help='Random seed')

# Logging parameters
parser.add_argument('--log_freq', type=int, default=100, help='Frequency of logging')
parser.add_argument('--val_freq', type=int, default=100, help='Frequency of validation')
parser.add_argument('--val_steps', type=int, default=250, help='Number of timestep to use for validation')
parser.add_argument('--log_dir', type=str, default=default_logdir, help='Logging directory')
parser.add_argument('--n_sampled_imgs', type=int, default=8, help='Number of images to sample during logging')
parser.add_argument('--max_val_steps', type=int, default=-1, help='Number of validation steps to perform')

# datasetGAN like segmentation model parameters
parser.add_argument("--saved_diffusion_model", type=str, help='Path to checkpoint of trained diffusion model', default="logs/20230127_164150/best_model.pt")
# metavar keeps usage/help readable; without it argparse prints all 1000 choices.
parser.add_argument("--t_steps_to_save", type=int, nargs='*', choices=range(1000), metavar='T',
                    help='Diffusion steps to be used as features', default=[50, 200, 400, 600, 800])
parser.add_argument("--n_labelled_images", type=int, help='Number of labelled images to use for semi-supervised training', default=None,
                    choices=[197, 98, 49, 24, 12, 6, 3, 1])

# other experiments I played with
parser.add_argument("--shared_weights_over_timesteps", help='In datasetDM, only use last timestep to predict, and intermediate timesteps to train', default=False, action='store_true')
# '%' must be doubled: argparse %-formats help strings, and a bare '%' makes --help raise.
parser.add_argument("--early_stop", help='In baseline, if validation loss increases by more than 50%%, stop', default=False, action='store_true')