import torch
import argparse
from modules.dataloader import R2DataLoader
from modules.tokenizers import Tokenizer
from modules.loss import compute_loss
from modules.metrics import compute_scores
from modules.optimizers import build_optimizer, build_lr_scheduler
from models.models import MedCapModel
from modules.trainer import Trainer
import numpy as np


def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy.

    Args:
        value: the raw command-line token (or an actual ``bool``).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % (value,))


def _build_parser():
    """Create the argument parser holding every command-line option of the script."""
    parser = argparse.ArgumentParser()

    # Data input Settings
    parser.add_argument('--json_path', default='data/mimic_cxr/annotation.json',
                        help='Path to the json file')
    parser.add_argument('--image_dir', default='data/mimic_cxr/images/',
                        help='Directory of images')

    # Dataloader Settings
    parser.add_argument('--dataset', default='mimic_cxr', help='dataset for training MedCap')
    parser.add_argument('--bs', type=int, default=16)
    parser.add_argument('--threshold', type=int, default=10,
                        help='the cut off frequency for the words.')
    parser.add_argument('--num_workers', type=int, default=2,
                        help='the number of workers for dataloader.')
    parser.add_argument('--max_seq_length', type=int, default=1024,
                        help='the maximum sequence length of the reports.')

    # Trainer Settings
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--n_gpu', type=int, default=1, help='the number of gpus to be used.')
    parser.add_argument('--save_dir', type=str, default='results/mimic_cxr/',
                        help='the path to save the models.')
    parser.add_argument('--record_dir', type=str, default='./record_dir/',
                        help='the path to save the results of experiments.')
    parser.add_argument('--log_period', type=int, default=1000,
                        help='the logging interval (in batches).')
    parser.add_argument('--save_period', type=int, default=1)
    parser.add_argument('--monitor_mode', type=str, default='max', choices=['min', 'max'],
                        help='whether to max or min the metric.')
    parser.add_argument('--monitor_metric', type=str, default='BLEU_4',
                        help='the metric to be monitored.')
    parser.add_argument('--early_stop', type=int, default=50, help='the patience of training.')

    # Training related
    parser.add_argument('--noise_inject', default='no', choices=['yes', 'no'])

    # Sample related
    parser.add_argument('--sample_method', type=str, default='greedy',
                        help='the sample methods to sample a report.')
    parser.add_argument('--prompt', default='/prompt/prompt.pt')
    parser.add_argument('--prompt_load', default='no', choices=['yes', 'no'])

    # Optimization
    parser.add_argument('--optim', type=str, default='Adam', help='the type of the optimizer.')
    parser.add_argument('--lr_ve', type=float, default=1e-5,
                        help='the learning rate for the visual extractor.')
    parser.add_argument('--lr_ed', type=float, default=5e-4,
                        help='the learning rate for the remaining parameters.')
    parser.add_argument('--weight_decay', type=float, default=5e-5, help='the weight decay.')
    # Two floats are expected (``--adam_betas 0.9 0.98``); ``type=tuple`` would
    # have split the argument string into single characters.
    parser.add_argument('--adam_betas', type=float, nargs=2, default=(0.9, 0.98),
                        metavar=('BETA1', 'BETA2'), help='the betas of Adam.')
    parser.add_argument('--adam_eps', type=float, default=1e-9, help='the epsilon of Adam.')
    parser.add_argument('--amsgrad', type=_str2bool, default=True, help='.')
    parser.add_argument('--noamopt_warmup', type=int, default=5000, help='.')
    parser.add_argument('--noamopt_factor', type=int, default=1, help='.')

    # Learning Rate Scheduler
    parser.add_argument('--lr_scheduler', type=str, default='StepLR',
                        help='the type of the learning rate scheduler.')
    parser.add_argument('--step_size', type=int, default=50,
                        help='the step size of the learning rate scheduler.')
    parser.add_argument('--gamma', type=float, default=0.1,
                        help='the gamma of the learning rate scheduler.')

    # Others
    parser.add_argument('--seed', type=int, default=9153, help='.')
    parser.add_argument('--resume', type=str,
                        help='whether to resume the training from existing checkpoints.')
    parser.add_argument('--train_mode', default='base', choices=['base', 'fine-tuning'],
                        help='Training mode: base (autoencoding) or fine-tuning (full supervised '
                             'training or fine-tuned on downstream datasets)')
    parser.add_argument('--F_version', default='v1', choices=['v1', 'v2'],)
    parser.add_argument('--clip_update', default='no', choices=['yes', 'no'])

    # Fine-tuning
    parser.add_argument('--random_init', default='yes', choices=['yes', 'no'],
                        help="Fine-tuning only: 'no' loads the pre-trained weights from "
                             "--weight_path, 'yes' keeps the freshly constructed model.")
    parser.add_argument('--weight_path', default='path_to_default_weights', type=str,
                        help='Path to the pre-trained model weights.')
    return parser


def _fix_seeds(seed):
    """Seed torch and numpy and put cuDNN into deterministic mode."""
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)


def main():
    """Parse the command line, build every training component and run the trainer."""
    args = _build_parser().parse_args()
    # ``nargs=2`` yields a list when given on the command line; keep the
    # attribute a tuple in both cases, as the default always was.
    args.adam_betas = tuple(args.adam_betas)

    # fix random seeds
    _fix_seeds(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # create tokenizer
    tokenizer = Tokenizer(args)

    # create data loader
    train_dataloader = R2DataLoader(args, tokenizer, split='train', shuffle=True)
    val_dataloader = R2DataLoader(args, tokenizer, split='val', shuffle=False)
    test_dataloader = R2DataLoader(args, tokenizer, split='test', shuffle=False)

    # get function handles of loss and metrics
    criterion = compute_loss
    metrics = compute_scores

    model = MedCapModel(args, tokenizer)

    if args.train_mode == 'fine-tuning' and args.random_init == 'no':
        # Load weights from the specified path. ``map_location`` lets a
        # checkpoint saved on a GPU be restored on a CPU-only machine.
        # NOTE(review): torch.load unpickles the file -- only point
        # --weight_path at checkpoints from a trusted source.
        checkpoint = torch.load(args.weight_path, map_location=device)
        model.load_state_dict(checkpoint)

    # build optimizer, learning rate scheduler
    optimizer = build_optimizer(args, model)
    lr_scheduler = build_lr_scheduler(args, optimizer)

    # build trainer and start to train
    trainer = Trainer(model, criterion, metrics, optimizer, args, lr_scheduler,
                      train_dataloader, val_dataloader, test_dataloader)
    trainer.train()


if __name__ == '__main__':
    main()