import argparse
import torch
from dassl.utils import setup_logger, set_random_seed, collect_env_info
from dassl.config import get_cfg_default
from dassl.engine import build_trainer
# Custom datasets and trainers. These imports are needed for their side
# effects: each module registers its class with Dassl's registry so it can
# be looked up by name (e.g. via build_trainer).
import datasets.oxford_pets
import datasets.oxford_flowers
import datasets.fgvc_aircraft
import datasets.dtd
import datasets.eurosat
import datasets.stanford_cars
import datasets.food101
import datasets.sun397
import datasets.caltech101
import datasets.ucf101
import datasets.imagenet
import datasets.imagenet_sketch
import datasets.imagenetv2
import datasets.imagenet_a
import datasets.imagenet_r
import trainers.coop
import trainers.cocoop
import trainers.kgcoop
import trainers.zsclip
import trainers.maple
import trainers.independentVL
import trainers.promptsrc
import trainers.tcp
import trainers.supr
import trainers.supr_ens
import trainers.elp_promptsrc
import trainers.supr_promptsrc


def print_args(args, cfg):
    print("***************")
    print("** Arguments **")
    print("***************")
    optkeys = list(args.__dict__.keys())
    optkeys.sort()
    for key in optkeys:
        print("{}: {}".format(key, args.__dict__[key]))
    print("************")
    print("** Config **")
    print("************")
    print(cfg)


def reset_cfg(cfg, args):
    if args.root:
        cfg.DATASET.ROOT = args.root
    if args.output_dir:
        cfg.OUTPUT_DIR = args.output_dir
    if args.resume:
        cfg.RESUME = args.resume
    if args.seed:
        cfg.SEED = args.seed
    if args.source_domains:
        cfg.DATASET.SOURCE_DOMAINS = args.source_domains
    if args.target_domains:
        cfg.DATASET.TARGET_DOMAINS = args.target_domains
    if args.transforms:
        cfg.INPUT.TRANSFORMS = args.transforms
    if args.trainer:
        cfg.TRAINER.NAME = args.trainer
    if args.backbone:
        cfg.MODEL.BACKBONE.NAME = args.backbone
    if args.head:
        cfg.MODEL.HEAD.NAME = args.head


def extend_cfg(cfg):
    """
    Add new config variables.

    E.g.
        from yacs.config import CfgNode as CN
        cfg.TRAINER.MY_MODEL = CN()
        cfg.TRAINER.MY_MODEL.PARAM_A = 1.
        cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
        cfg.TRAINER.MY_MODEL.PARAM_C = False
    """
    from yacs.config import CfgNode as CN

    cfg.TRAINER.COOP = CN()
    cfg.TRAINER.COOP.N_CTX = 16  # number of context vectors
    cfg.TRAINER.COOP.CSC = False  # class-specific context
    cfg.TRAINER.COOP.CTX_INIT = ""  # initialization words
    cfg.TRAINER.COOP.PREC = "fp16"  # fp16, fp32, amp
    cfg.TRAINER.COOP.W = 8.0  # loss weight (used by KgCoOp)
    cfg.TRAINER.COOP.CLASS_TOKEN_POSITION = "end"  # 'middle' or 'end' or 'front'

    # Config for CoCoOp
    cfg.TRAINER.COCOOP = CN()
    cfg.TRAINER.COCOOP.N_CTX = 16  # number of context vectors
    cfg.TRAINER.COCOOP.CTX_INIT = ""  # initialization words
    cfg.TRAINER.COCOOP.PREC = "fp16"  # fp16, fp32, amp

    # Config for MaPLe
    cfg.TRAINER.MAPLE = CN()
    cfg.TRAINER.MAPLE.N_CTX = 2  # number of context vectors
    cfg.TRAINER.MAPLE.CTX_INIT = "a photo of a"  # initialization words
    cfg.TRAINER.MAPLE.PREC = "fp16"  # fp16, fp32, amp
    cfg.TRAINER.MAPLE.PROMPT_DEPTH = 9  # max 12, min 0; a depth of 1 acts as shallow MaPLe (J=1)
    cfg.DATASET.SUBSAMPLE_CLASSES = "all"  # all, base or new

    # Config for PromptSRC
    cfg.TRAINER.PROMPTSRC = CN()
    cfg.TRAINER.PROMPTSRC.N_CTX_VISION = 4  # number of context vectors at the vision branch
    cfg.TRAINER.PROMPTSRC.N_CTX_TEXT = 4  # number of context vectors at the language branch
    cfg.TRAINER.PROMPTSRC.CTX_INIT = "a photo of a"  # initialization words
    cfg.TRAINER.PROMPTSRC.PREC = "fp16"  # fp16, fp32, amp
    cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_VISION = 9  # max 12, min 0; the minimum acts as shallow IVLP prompting (J=1)
    cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_TEXT = 9  # max 12, min 0; the minimum acts as shallow IVLP prompting (J=1)
    cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT = 25
    cfg.TRAINER.PROMPTSRC.IMAGE_LOSS_WEIGHT = 10
    cfg.TRAINER.PROMPTSRC.GPA_MEAN = 15
    cfg.TRAINER.PROMPTSRC.GPA_STD = 1

    # Config for Independent Vision-Language Prompting (IVLP)
    cfg.TRAINER.IVLP = CN()
    cfg.TRAINER.IVLP.N_CTX_VISION = 2  # number of context vectors at the vision branch
    cfg.TRAINER.IVLP.N_CTX_TEXT = 2  # number of context vectors at the language branch
    cfg.TRAINER.IVLP.CTX_INIT = "a photo of a"  # initialization words (only for language prompts)
    cfg.TRAINER.IVLP.PREC = "fp16"  # fp16, fp32, amp
    # If both depths below are set to 0, the config degenerates to the CoOp model.
    cfg.TRAINER.IVLP.PROMPT_DEPTH_VISION = 9  # max 12, min 0; the minimum acts as shallow IVLP prompting (J=1)
    cfg.TRAINER.IVLP.PROMPT_DEPTH_TEXT = 9  # max 12, min 0; the minimum acts as shallow IVLP prompting (J=1)
    cfg.DATASET.SUBSAMPLE_CLASSES = "all"  # all, base or new

    cfg.TEST.NO_TEST = False

    # For DePT
    # linear classifier settings
    cfg.TRAINER.LINEAR_PROBE = CN()
    cfg.TRAINER.LINEAR_PROBE.TYPE = 'linear'
    cfg.TRAINER.LINEAR_PROBE.WEIGHT = 0.3
    cfg.TRAINER.LINEAR_PROBE.TEST_TIME_FUSION = True
    # cwT module settings
    cfg.TRAINER.FILM = CN()
    cfg.TRAINER.FILM.LINEAR_PROBE = True
    cfg.OPTIM.LR_EXP = 6.5
    cfg.OPTIM.NEW_LAYERS = ['linear_probe', 'film']

    # For TCP
    cfg.TRAINER.TCP = CN()
    cfg.TRAINER.TCP.N_CTX = 4  # number of context vectors
    cfg.TRAINER.TCP.CSC = False  # class-specific context
    cfg.TRAINER.TCP.CTX_INIT = ""  # initialization words
    cfg.TRAINER.TCP.PREC = "fp16"  # fp16, fp32, amp
    cfg.TRAINER.TCP.W = 1.0
    cfg.TRAINER.TCP.CLASS_TOKEN_POSITION = "end"

    # For SuPr
    cfg.TRAINER.SUPR = CN()
    cfg.TRAINER.SUPR.N_CTX_VISION = 4  # number of context vectors at the vision branch
    cfg.TRAINER.SUPR.N_CTX_TEXT = 4  # number of context vectors at the language branch
    cfg.TRAINER.SUPR.CTX_INIT = "a photo of a"  # initialization words
    cfg.TRAINER.SUPR.PREC = "fp16"  # fp16, fp32, amp
    cfg.TRAINER.SUPR.PROMPT_DEPTH_VISION = 9  # max 12, min 0; the minimum acts as shallow IVLP prompting (J=1)
    cfg.TRAINER.SUPR.PROMPT_DEPTH_TEXT = 9  # max 12, min 0; the minimum acts as shallow IVLP prompting (J=1)
    cfg.TRAINER.SUPR.SPACE_DIM = 7  # subspace dimension
    cfg.TRAINER.SUPR.ENSEMBLE_NUM = 3  # for SuPr Ens
    cfg.TRAINER.SUPR.REG_LOSS_WEIGHT = 60  # regularization loss weight lambda
    cfg.TRAINER.SUPR.LAMBDA = 0.7  # balance coefficient gamma
    cfg.TRAINER.SUPR.SVD = True
    cfg.TRAINER.SUPR.HARD_PROMPT_PATH = "configs/trainers/SuPr/hard_prompts/"
    cfg.TRAINER.SUPR.TRAINER_BACKBONE = "SuPr"


def setup_cfg(args):
    cfg = get_cfg_default()
    extend_cfg(cfg)

    # 1. From the dataset config file
    if args.dataset_config_file:
        cfg.merge_from_file(args.dataset_config_file)

    # 2. From the method config file
    if args.config_file:
        cfg.merge_from_file(args.config_file)

    # 3. From input arguments
    reset_cfg(cfg, args)

    # 4. From optional input arguments
    cfg.merge_from_list(args.opts)
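    # Merges are applied in order, so command-line `opts` (step 4) override
    # both config files and the argument-based overrides above.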
    cfg.freeze()
    return cfg


def main(args):
    cfg = setup_cfg(args)
    if cfg.SEED >= 0:
        print("Setting fixed seed: {}".format(cfg.SEED))
        set_random_seed(cfg.SEED)
    setup_logger(cfg.OUTPUT_DIR)

    if torch.cuda.is_available() and cfg.USE_CUDA:
        # cudnn autotuning picks the fastest convolution kernels for the
        # fixed input sizes used during training
        torch.backends.cudnn.benchmark = True

    print_args(args, cfg)
    print("Collecting env info ...")
    print("** System info **\n{}\n".format(collect_env_info()))

    trainer = build_trainer(cfg)

    if args.eval_only:
        trainer.load_model(args.model_dir, epoch=args.load_epoch)
        trainer.test()
        return

    if not args.no_train:
        trainer.train()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, default="", help="path to dataset")
    parser.add_argument("--output-dir", type=str, default="", help="output directory")
    parser.add_argument(
        "--resume",
        type=str,
        default="",
        help="checkpoint directory (from which the training resumes)",
    )
    parser.add_argument(
        "--seed", type=int, default=-1, help="only positive value enables a fixed seed"
    )
    parser.add_argument(
        "--source-domains", type=str, nargs="+", help="source domains for DA/DG"
    )
    parser.add_argument(
        "--target-domains", type=str, nargs="+", help="target domains for DA/DG"
    )
    parser.add_argument(
        "--transforms", type=str, nargs="+", help="data augmentation methods"
    )
    parser.add_argument(
        "--config-file", type=str, default="", help="path to config file"
    )
    parser.add_argument(
        "--dataset-config-file",
        type=str,
        default="",
        help="path to config file for dataset setup",
    )
    parser.add_argument("--trainer", type=str, default="", help="name of trainer")
    parser.add_argument("--backbone", type=str, default="", help="name of CNN backbone")
    parser.add_argument("--head", type=str, default="", help="name of head")
    parser.add_argument("--eval-only", action="store_true", help="evaluation only")
    parser.add_argument(
        "--model-dir",
        type=str,
        default="",
        help="load model from this directory for eval-only mode",
    )
    parser.add_argument(
        "--load-epoch", type=int, help="load model weights at this epoch for evaluation"
    )
    parser.add_argument(
        "--no-train", action="store_true", help="do not call trainer.train()"
    )
    parser.add_argument(
        "opts",
        default=None,
        nargs=argparse.REMAINDER,
        help="modify config options using the command-line",
    )
    args = parser.parse_args()
    main(args)
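
# Example invocation (paths and config file names are illustrative; adjust
# them to your own dataset root and the configs shipped with this repo):
#
#   python train.py --root /path/to/datasets --trainer SuPr \
#       --dataset-config-file configs/datasets/oxford_pets.yaml \
#       --config-file configs/trainers/SuPr/vit_b16.yaml \
#       --output-dir output/supr/oxford_pets \
#       DATASET.SUBSAMPLE_CLASSES base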