""" Puzzle Tuning Script ver: Feb 11th 14:00 Paper: https://arxiv.org/abs/2311.06712 Code: https://github.com/sagizty/PuzzleTuning Ref: MAE https://github.com/facebookresearch/mae Step 1: PreTraining on the ImagetNet-1k style dataset (others) Step 2: Domain Prompt Tuning (PuzzleTuning) on Pathological Images (in ImageFolder) Step 3: FineTuning on the Downstream Tasks This is the training code for step 2 Pre-training Experiments: DP (data-parallel bash) python PuzzleTuning.py --batch_size 64 --blr 1.5e-4 --epochs 200 --accum_iter 2 --print_freq 2000 --check_point_gap 50 --input_size 224 --warmup_epochs 20 --pin_mem --num_workers 32 --strategy loop --PromptTuning Deep --basic_state_dict /data/saved_models/ViT_b16_224_Imagenet.pth --data_path /root/datasets/All DDP (distributed data-parallel bash) for one machine with 12 GPU python -m torch.distributed.launch --nproc_per_node=12 --nnodes 1 --node_rank 0 PuzzleTuning.py --DDP_distributed --batch_size 64 --blr 1.5e-4 --epochs 200 --accum_iter 2 --print_freq 2000 --check_point_gap 50 --input_size 224 --warmup_epochs 20 --pin_mem --num_workers 32 --strategy loop --PromptTuning Deep --basic_state_dict /data/saved_models/ViT_b16_224_Imagenet.pth --data_path /root/datasets/All update: Use "--seg_decoder" parameter to introduce segmentation networks swin_unet for Swin-Unet """ import argparse import datetime import json import numpy as np import os import time from pathlib import Path import torch import torch.backends.cudnn as cudnn from tensorboardX import SummaryWriter import torchvision.transforms as transforms import torchvision.datasets as datasets import timm # assert timm.__version__ == "0.3.2" # version check import timm.optim.optim_factory as optim_factory import SSL_structures.misc as misc from SSL_structures.misc import NativeScalerWithGradNormCount as NativeScaler from utils.schedulers import patch_scheduler, ratio_scheduler from SSL_structures import models_mae, SAE from SSL_structures.engine_pretrain import train_one_epoch def main(args): # choose encoder for timm basic_encoder = args.model[4:] # choose decoder version args.model = args.model + '_decoder' if args.seg_decoder is not None else args.model # note decoder args.model_idx = args.model_idx + args.model + '_' + args.seg_decoder if args.seg_decoder is not None \ else args.model_idx + args.model # note PromptTuning args.model_idx = args.model_idx + '_Prompt_' + args.PromptTuning + '_tokennum_' + str(args.Prompt_Token_num) \ if args.PromptTuning is not None else args.model_idx # fix the seed for reproducibility if args.DDP_distributed: misc.init_distributed_mode(args) seed = args.seed + misc.get_rank() else: seed = args.seed torch.manual_seed(seed) np.random.seed(seed) # set GPUs cudnn.benchmark = True device = torch.device(args.device) # cuda # simple augmentation transform_train = transforms.Compose([ transforms.RandomResizedCrop(args.input_size, scale=(0.8, 1.0), interpolation=3, ratio=(1. / 1., 1. 
    dataset_train = datasets.ImageFolder(os.path.join(args.data_path), transform=transform_train)  # , 'train'
    print('dataset_train:', dataset_train)

    # Train data
    if args.DDP_distributed:
        # args.DDP_distributed is True: use distributed data parallel (DDP)
        num_tasks = misc.get_world_size()  # use misc to set up DDP
        global_rank = misc.get_rank()  # get the rank of the current running process
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True)
        print("Sampler_train = %s" % str(sampler_train))
        enable_DistributedSampler = True
        batch_size_for_Dataloader = args.batch_size
    else:
        # data parallel (DP) instead of distributed data parallel (DDP)
        global_rank = 0
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        enable_DistributedSampler = False
        batch_size_for_Dataloader = args.batch_size * torch.cuda.device_count()

    # set log on the main process
    if global_rank == 0 and args.log_dir is not None:
        args.log_dir = os.path.join(args.log_dir, args.model_idx)
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)  # Tensorboard
        print('Task: ' + args.model_idx)
        print("Use", torch.cuda.device_count(), "GPUs!")
        print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
        print("{}".format(args).replace(', ', ',\n'))
    else:
        log_writer = None

    # output_dir
    if args.output_dir is not None:
        args.output_dir = os.path.join(args.output_dir, args.model_idx)
        os.makedirs(args.output_dir, exist_ok=True)
        os.makedirs(os.path.join(args.output_dir, 'figs'), exist_ok=True)
        print('Training output files will be at', args.output_dir)
    else:
        raise ValueError('no output path specified!')

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,  # shuffle=True is already set in the sampler
        batch_size=batch_size_for_Dataloader,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True)
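    # Optional sanity check (illustrative addition, not part of the original training flow):
    # peek at one batch to confirm the ImageFolder layout and the transform pipeline behave as expected.
    # images, labels = next(iter(data_loader_train))
    # print('sanity-check batch:', images.shape)  # expect [batch_size_for_Dataloader, 3, input_size, input_size]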
    # define the model
    if args.model[0:3] == 'mae':
        if args.basic_state_dict is not None:
            # Transfer-learning
            try:
                if args.basic_state_dict == 'timm':
                    basic_model = timm.create_model('vit_base_patch' + str(16) + '_' + str(args.input_size),
                                                    pretrained=True)
                    basic_state_dict = basic_model.state_dict()
                    print('MAE Transfer-learning with timm')
                else:
                    basic_state_dict = torch.load(args.basic_state_dict)
                    if 'model' in basic_state_dict:
                        basic_state_dict = basic_state_dict['model']
            except:
                print('error in args.basic_state_dict:', args.basic_state_dict)
                if args.PromptTuning is not None:
                    print('In PromptTuning, the basic_state_dict is required; without a valid one, timm is loaded.\n')
                    # timm model name: basic_encoder
                    basic_model = timm.create_model(basic_encoder + '_' + str(args.input_size), pretrained=True)
                    basic_state_dict = basic_model.state_dict()
                else:
                    basic_state_dict = None
                    print('MAE Restart with an empty backbone')
            else:
                print('MAE Transfer-learning with:', args.basic_state_dict)
        else:
            if args.PromptTuning is not None:
                print('In PromptTuning, the basic_state_dict is required; without specification, timm is loaded.\n')
                # timm model name: basic_encoder
                basic_model = timm.create_model(basic_encoder + '_' + str(args.input_size), pretrained=True)
                basic_state_dict = basic_model.state_dict()
            else:
                basic_state_dict = None
                print('MAE Restart with an empty backbone')

        # mae-vit-base-patch16
        model = models_mae.__dict__[args.model](img_size=args.input_size, norm_pix_loss=args.norm_pix_loss,
                                                prompt_mode=args.PromptTuning, Prompt_Token_num=args.Prompt_Token_num,
                                                basic_state_dict=basic_state_dict, dec_idx=args.seg_decoder)
        # set the puzzle schedulers to None so the SAE mechanism is not used
        puzzle_patch_size_scheduler = None
        fix_position_ratio_scheduler = None

    # PuzzleTuning
    elif args.model[0:3] == 'sae':
        if args.basic_state_dict is not None:
            try:
                if args.basic_state_dict == 'timm':
                    print("using timm")
                    basic_model = timm.create_model(basic_encoder + '_' + str(args.input_size), pretrained=True)
                    basic_state_dict = basic_model.state_dict()
                else:
                    basic_state_dict = torch.load(args.basic_state_dict)
            except:
                print('error in args.basic_state_dict:', args.basic_state_dict)
                if args.PromptTuning is not None:
                    print('In PromptTuning, the basic_state_dict is required; without a valid one, timm is loaded.\n')
                    # timm model name: basic_encoder
                    basic_model = timm.create_model(basic_encoder + '_' + str(args.input_size), pretrained=True)
                    basic_state_dict = basic_model.state_dict()
                else:
                    basic_state_dict = None
                    print('SAE Restart with an empty backbone')
            else:
                print('Puzzle tuning with Transfer-learning:', args.basic_state_dict)
        else:
            if args.PromptTuning is not None:
                print('In PromptTuning, the basic_state_dict is required; without specification, timm is loaded.\n')
                # timm model name: basic_encoder
                basic_model = timm.create_model(basic_encoder + '_' + str(args.input_size), pretrained=True)
                basic_state_dict = basic_model.state_dict()
            else:
                basic_state_dict = None
                print('Puzzle tuning with an empty backbone')

        model = SAE.__dict__[args.model](img_size=args.input_size, group_shuffle_size=args.group_shuffle_size,
                                         norm_pix_loss=args.norm_pix_loss, prompt_mode=args.PromptTuning,
                                         Prompt_Token_num=args.Prompt_Token_num, basic_state_dict=basic_state_dict,
                                         dec_idx=args.seg_decoder)

        fix_position_ratio_scheduler = ratio_scheduler(total_epoches=args.epochs,
                                                       warmup_epochs=args.warmup_epochs,
                                                       basic_ratio=0.25,  # start ratio
                                                       fix_position_ratio=args.fix_position_ratio,  # None
                                                       strategy=args.strategy)  # strategy=None for fixed, else reduce the ratio gradually

        # set the puzzle_patch_size scheduler (not used by MAE)
        puzzle_patch_size_scheduler = patch_scheduler(total_epoches=args.epochs,
                                                      warmup_epochs=args.warmup_epochs,
                                                      edge_size=args.input_size,
                                                      basic_patch=model.patch_embed.patch_size[0],
                                                      fix_patch_size=args.fix_patch_size,  # None
                                                      strategy=args.strategy)  # 'linear'
        # NOTICE: the strategy is used for setting up both the ratio-scheduler and the patch-scheduler

    else:
        print('This Tuning script only supports SAE (PuzzleTuning) or MAE')
        return -1

    # the effective batch size for setting up lr
    if args.DDP_distributed:
        eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
    else:
        eff_batch_size = args.batch_size * torch.cuda.device_count() * args.accum_iter
    print('eff_batch_size:', eff_batch_size)

    if args.lr is None:  # when only base_lr is specified
        args.lr = args.blr * eff_batch_size / 256

    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)
    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)
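    # Worked example (added): with the DDP command from the module docstring
    # (12 GPUs, --batch_size 64, --accum_iter 2, --blr 1.5e-4):
    #   eff_batch_size = 64 * 2 * 12 = 1536
    #   actual lr = blr * eff_batch_size / 256 = 1.5e-4 * 1536 / 256 = 9.0e-4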
    # take the model parameters for the optimizer update
    model_without_ddp = model

    if args.DDP_distributed:
        model.cuda()  # args.gpu is obtained by misc.py
        # find_unused_parameters=True for DDP to correctly synchronize layers in back propagation
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
    else:
        model = torch.nn.DataParallel(model)
        model.to(device)

    print("Model = %s" % str(model_without_ddp))

    # following timm: set wd as 0 for bias and norm layers
    param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
    print(optimizer)

    # loss scaler with gradient clipping
    loss_scaler = NativeScaler(GPU_count=torch.cuda.device_count(), DDP_distributed=args.DDP_distributed)

    # if --resume is given, load the checkpoint and continue training; otherwise start a new training
    # the checkpoint should include model, optimizer and loss_scaler information
    misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)

    # Training by epochs
    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()

    for epoch in range(args.start_epoch, args.epochs):  # use args.start_epoch to jump to the resume checkpoint
        if enable_DistributedSampler:
            # DistributedSampler needs .set_epoch(epoch) at each epoch
            data_loader_train.sampler.set_epoch(epoch)

        # training iterations
        train_stats = train_one_epoch(model, data_loader_train, optimizer, device, epoch, loss_scaler,
                                      fix_position_ratio_scheduler=fix_position_ratio_scheduler,
                                      puzzle_patch_size_scheduler=puzzle_patch_size_scheduler,
                                      check_samples=args.check_samples, print_freq=args.print_freq,
                                      log_writer=log_writer, args=args)

        if args.output_dir and (epoch % args.check_point_gap == 0 or epoch + 1 == args.epochs):
            misc.save_model(args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                            loss_scaler=loss_scaler, epoch=epoch, model_idx=args.model_idx)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch}

        # Write log
        if args.output_dir and misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

    # time stamp
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


def get_args_parser():
    parser = argparse.ArgumentParser('SAE pre-training', add_help=False)

    # disable_notify
    parser.add_argument('--disable_notify', action='store_true', help='do not send tracking emails')

    # Model name or index
    parser.add_argument('--model_idx', default='PuzzleTuning_', type=str, help='Model name or index')

    # Original MAE (224->64), MAE A100 (224->256, 384->128), SAE (224->128, 384->64), SAE-VPT (224->256, 384->128)
    parser.add_argument('--batch_size', default=64, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
    parser.add_argument('--epochs', default=200, type=int)  # originally 800 epochs
    parser.add_argument('--accum_iter', default=2, type=int,
                        help='Accumulate gradient iterations '
                             '(for increasing the effective batch size under memory constraints)')

    # if --resume is given, load the checkpoint and continue training; otherwise start a new training
    # the checkpoint should include model, optimizer and loss_scaler information
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch of checkpoint')

    # Model parameters: sae_vit_base_patch16 / mae_vit_base_patch16
    parser.add_argument('--model', default='sae_vit_base_patch16', type=str, metavar='MODEL',
                        help='Name of model to train')  # ori mae_vit_large_patch16
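    # Note (added): when --seg_decoder is set, main() appends '_decoder' to this model name,
    # so a matching '*_decoder' builder must exist in SSL_structures.SAE / models_mae.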
    parser.add_argument('--seg_decoder', default=None, type=str, metavar='segmentation decoder',
                        help='Name of segmentation decoder')
    parser.add_argument('--input_size', default=224, type=int, help='images input size')
    parser.add_argument('--model_patch_size', default=16, type=int,
                        help='model_patch_size, default 16 for ViT-base')
    parser.add_argument('--num_classes', default=3, type=int,  # decoder seg class set to channel
                        help='the number of classes for segmentation')

    # MAE mask_ratio
    parser.add_argument('--mask_ratio', default=0.75, type=float,
                        help='Masking ratio (percentage of removed patches)')

    # Tuning settings
    # PromptTuning
    parser.add_argument('--PromptTuning', default=None, type=str,
                        help='use Prompt Tuning strategy (Deep/Shallow) instead of Finetuning (None, by default)')
    # Prompt_Token_num
    parser.add_argument('--Prompt_Token_num', default=20, type=int, help='Prompt_Token_num for the VPT backbone')

    # Curriculum learning settings
    parser.add_argument('--strategy', default=None, type=str,
                        help='use linear or another puzzle size scheduler')
    parser.add_argument('--fix_position_ratio', default=None, type=float,
                        help='ablation fix_position_ratio (percentage of position token patches)')
    parser.add_argument('--fix_patch_size', default=None, type=int, help='ablation using fix_patch_size')
    parser.add_argument('--group_shuffle_size', default=-1, type=int,
                        help='group_shuffle_size of group shuffling, '
                             'default -1 for the whole batch as a group')

    # loss settings
    parser.add_argument('--norm_pix_loss', action='store_true',
                        help='Use (per-patch) normalized pixels as targets for computing loss')
    parser.set_defaults(norm_pix_loss=False)

    # basic_state_dict
    parser.add_argument('--basic_state_dict', default=None, type=str,
                        help='load a basic backbone state_dict for Transfer-learning-based tuning, default None')

    # Optimizer settings
    parser.add_argument('--weight_decay', type=float, default=0.05, help='weight decay (default: 0.05)')
    parser.add_argument('--lr', type=float, default=None, metavar='LR',
                        help='learning rate (absolute lr), default=None')
    parser.add_argument('--blr', type=float, default=1.5e-4, metavar='LR',
                        help='base learning rate: absolute_lr = base_lr * effective batch size / 256')
    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')
    parser.add_argument('--warmup_epochs', type=int, default=20, metavar='N', help='epochs to warmup LR')

    # PATH settings
    # Dataset parameters: /datasets01/imagenet_full_size/061417/  /data/imagenet_1k  /root/autodl-tmp/imagenet
    parser.add_argument('--data_path', default='/root/autodl-tmp/datasets/All', type=str, help='dataset path')
    parser.add_argument('--output_dir', default='/root/autodl-tmp/runs',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='/root/tf-logs', help='path where to write the tensorboard log')
    parser.add_argument('--device', default='cuda', help='device to use for training / testing')
    parser.add_argument('--seed', default=42, type=int)

    # dataloader settings
    parser.add_argument('--num_workers', default=20, type=int)
    # 4*A100 (16, 384, b128, shm40)  6*A100 (36, 384, b128, shm100)  8*A100 (35, 384, b128, shm100)
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # print_freq and checkpoint
    parser.add_argument('--print_freq', default=20, type=int)
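    # Note (added): the training loop saves a checkpoint via misc.save_model every
    # check_point_gap epochs (epoch % check_point_gap == 0) and on the final epoch.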
    parser.add_argument('--check_point_gap', default=50, type=int)
    parser.add_argument('--check_samples', default=1, type=int, help='check how many images in a checking batch')

    # DDP training parameters
    parser.add_argument('--world_size', default=1, type=int, help='number of DDP_distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://', help='url used to set up DDP_distributed training')
    parser.add_argument('--DDP_distributed', action='store_true',
                        help='Use DDP in training; if not set, DP will be applied')

    return parser


if __name__ == '__main__':
    args = get_args_parser()
    args = args.parse_args()

    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    main(args)
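# Note (added): on recent PyTorch releases, `python -m torch.distributed.launch` (used in the module
# docstring) is deprecated in favour of torchrun. Assuming SSL_structures.misc.init_distributed_mode
# reads the standard RANK / LOCAL_RANK / WORLD_SIZE environment variables, an equivalent launch would be:
#   torchrun --nproc_per_node=12 PuzzleTuning.py --DDP_distributed ...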