|
import argparse |
|
import torch |
|
|
|
from dassl.utils import setup_logger, set_random_seed, collect_env_info |
|
from dassl.config import get_cfg_default |
|
from dassl.engine import build_trainer |
|
|
|
|
|
import datasets.oxford_pets |
|
import datasets.oxford_flowers |
|
import datasets.fgvc_aircraft |
|
import datasets.dtd |
|
import datasets.eurosat |
|
import datasets.stanford_cars |
|
import datasets.food101 |
|
import datasets.sun397 |
|
import datasets.caltech101 |
|
import datasets.ucf101 |
|
import datasets.imagenet |
|
|
|
import datasets.imagenet_sketch |
|
import datasets.imagenetv2 |
|
import datasets.imagenet_a |
|
import datasets.imagenet_r |
|
|
|
import trainers.coop |
|
import trainers.cocoop |
|
import trainers.kgcoop |
|
import trainers.zsclip |
|
import trainers.maple |
|
import trainers.independentVL |
|
import trainers.promptsrc |
|
import trainers.tcp |
|
import trainers.supr |
|
import trainers.supr_ens |
|
import trainers.elp_promptsrc |
|
import trainers.supr_promptsrc |
|
|
|
|
|
def print_args(args, cfg):
    """Print the parsed command-line arguments and the final config.

    Arguments are listed one per line as ``name: value``, sorted by name,
    under an "Arguments" banner. The config object follows, printed as-is,
    under a "Config" banner.

    Args:
        args: the ``argparse.Namespace`` produced by the script's parser.
        cfg: the config object to display (printed via ``print``).
    """
    arg_table = vars(args)

    for banner_line in ("***************", "** Arguments **", "***************"):
        print(banner_line)
    for name in sorted(arg_table):
        print("{}: {}".format(name, arg_table[name]))

    for banner_line in ("************", "** Config **", "************"):
        print(banner_line)
    print(cfg)
|
|
|
|
|
def reset_cfg(cfg, args):
    """Copy explicitly supplied command-line values into ``cfg`` (in place).

    Most string/list options are applied only when truthy, so an empty string
    or ``None`` (the parser defaults) leaves the config-file value in place.

    The seed is the exception: it is applied whenever it is not ``None``.
    A plain truthiness test would ignore an explicit ``--seed 0``, because 0
    is falsy, even though ``main`` treats any ``cfg.SEED >= 0`` as a fixed
    seed. Note that the parser's default of -1 therefore also overwrites
    ``cfg.SEED``, exactly as before.

    Args:
        cfg: mutable config object with the DATASET/INPUT/TRAINER/MODEL
            sub-nodes referenced below.
        args: parsed command-line namespace.
    """
    if args.root:
        cfg.DATASET.ROOT = args.root

    if args.output_dir:
        cfg.OUTPUT_DIR = args.output_dir

    if args.resume:
        cfg.RESUME = args.resume

    # Compare against None rather than truthiness so `--seed 0` is honoured.
    if args.seed is not None:
        cfg.SEED = args.seed

    if args.source_domains:
        cfg.DATASET.SOURCE_DOMAINS = args.source_domains

    if args.target_domains:
        cfg.DATASET.TARGET_DOMAINS = args.target_domains

    if args.transforms:
        cfg.INPUT.TRANSFORMS = args.transforms

    if args.trainer:
        cfg.TRAINER.NAME = args.trainer

    if args.backbone:
        cfg.MODEL.BACKBONE.NAME = args.backbone

    if args.head:
        cfg.MODEL.HEAD.NAME = args.head
|
|
|
|
|
def extend_cfg(cfg):
    """Register the extra config options used by the trainers in this project.

    ``get_cfg_default()`` provides Dassl's base config. The trainer-specific
    nodes below are attached to it here, before ``setup_cfg`` merges the YAML
    config files and the command-line ``opts``, so those sources can override
    the defaults set here.

    Pattern for adding options for a new trainer::

        from yacs.config import CfgNode as CN
        cfg.TRAINER.MY_MODEL = CN()
        cfg.TRAINER.MY_MODEL.PARAM_A = 1.
        cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
        cfg.TRAINER.MY_MODEL.PARAM_C = False

    Args:
        cfg: the still-mutable yacs ``CfgNode`` from ``get_cfg_default()``;
            it is modified in place and nothing is returned.
    """
    from yacs.config import CfgNode as CN

    # ---- Options outside cfg.TRAINER shared by several trainers -----------
    # This dataset-level option was previously assigned in more than one
    # trainer section; it is now set exactly once.
    cfg.DATASET.SUBSAMPLE_CLASSES = "all"
    cfg.TEST.NO_TEST = False

    # ---- CoOp -------------------------------------------------------------
    cfg.TRAINER.COOP = CN()
    cfg.TRAINER.COOP.N_CTX = 16
    cfg.TRAINER.COOP.CSC = False
    cfg.TRAINER.COOP.CTX_INIT = ""
    cfg.TRAINER.COOP.PREC = "fp16"
    cfg.TRAINER.COOP.W = 8.0
    cfg.TRAINER.COOP.CLASS_TOKEN_POSITION = "end"

    # ---- CoCoOp -----------------------------------------------------------
    cfg.TRAINER.COCOOP = CN()
    cfg.TRAINER.COCOOP.N_CTX = 16
    cfg.TRAINER.COCOOP.CTX_INIT = ""
    cfg.TRAINER.COCOOP.PREC = "fp16"

    # ---- MaPLe ------------------------------------------------------------
    cfg.TRAINER.MAPLE = CN()
    cfg.TRAINER.MAPLE.N_CTX = 2
    cfg.TRAINER.MAPLE.CTX_INIT = "a photo of a"
    cfg.TRAINER.MAPLE.PREC = "fp16"
    cfg.TRAINER.MAPLE.PROMPT_DEPTH = 9

    # ---- PromptSRC --------------------------------------------------------
    cfg.TRAINER.PROMPTSRC = CN()
    cfg.TRAINER.PROMPTSRC.N_CTX_VISION = 4
    cfg.TRAINER.PROMPTSRC.N_CTX_TEXT = 4
    cfg.TRAINER.PROMPTSRC.CTX_INIT = "a photo of a"
    cfg.TRAINER.PROMPTSRC.PREC = "fp16"
    cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_VISION = 9
    cfg.TRAINER.PROMPTSRC.PROMPT_DEPTH_TEXT = 9
    cfg.TRAINER.PROMPTSRC.TEXT_LOSS_WEIGHT = 25
    cfg.TRAINER.PROMPTSRC.IMAGE_LOSS_WEIGHT = 10
    cfg.TRAINER.PROMPTSRC.GPA_MEAN = 15
    cfg.TRAINER.PROMPTSRC.GPA_STD = 1

    # ---- Independent vision-language prompting (IVLP) ---------------------
    cfg.TRAINER.IVLP = CN()
    cfg.TRAINER.IVLP.N_CTX_VISION = 2
    cfg.TRAINER.IVLP.N_CTX_TEXT = 2
    cfg.TRAINER.IVLP.CTX_INIT = "a photo of a"
    cfg.TRAINER.IVLP.PREC = "fp16"
    cfg.TRAINER.IVLP.PROMPT_DEPTH_VISION = 9
    cfg.TRAINER.IVLP.PROMPT_DEPTH_TEXT = 9

    # ---- Linear probe -----------------------------------------------------
    cfg.TRAINER.LINEAR_PROBE = CN()
    cfg.TRAINER.LINEAR_PROBE.TYPE = 'linear'
    cfg.TRAINER.LINEAR_PROBE.WEIGHT = 0.3
    cfg.TRAINER.LINEAR_PROBE.TEST_TIME_FUSION = True

    # ---- FiLM and the optimizer options listed alongside it ---------------
    cfg.TRAINER.FILM = CN()
    cfg.TRAINER.FILM.LINEAR_PROBE = True
    cfg.OPTIM.LR_EXP = 6.5
    # The names here match the 'linear_probe' / 'film' options above.
    # NOTE(review): how trainers consume this list is not visible here.
    cfg.OPTIM.NEW_LAYERS = ['linear_probe', 'film']

    # ---- TCP --------------------------------------------------------------
    cfg.TRAINER.TCP = CN()
    cfg.TRAINER.TCP.N_CTX = 4
    cfg.TRAINER.TCP.CSC = False
    cfg.TRAINER.TCP.CTX_INIT = ""
    cfg.TRAINER.TCP.PREC = "fp16"
    cfg.TRAINER.TCP.W = 1.0
    cfg.TRAINER.TCP.CLASS_TOKEN_POSITION = "end"

    # ---- SuPr -------------------------------------------------------------
    cfg.TRAINER.SUPR = CN()
    cfg.TRAINER.SUPR.N_CTX_VISION = 4
    cfg.TRAINER.SUPR.N_CTX_TEXT = 4
    cfg.TRAINER.SUPR.CTX_INIT = "a photo of a"
    cfg.TRAINER.SUPR.PREC = "fp16"
    cfg.TRAINER.SUPR.PROMPT_DEPTH_VISION = 9
    cfg.TRAINER.SUPR.PROMPT_DEPTH_TEXT = 9
    cfg.TRAINER.SUPR.SPACE_DIM = 7
    cfg.TRAINER.SUPR.ENSEMBLE_NUM = 3
    cfg.TRAINER.SUPR.REG_LOSS_WEIGHT = 60
    cfg.TRAINER.SUPR.LAMBDA = 0.7
    cfg.TRAINER.SUPR.SVD = True
    cfg.TRAINER.SUPR.HARD_PROMPT_PATH = "configs/trainers/SuPr/hard_prompts/"
    cfg.TRAINER.SUPR.TRAINER_BACKBONE = "SuPr"
|
|
|
|
|
def setup_cfg(args):
    """Assemble and freeze the run configuration.

    Sources are layered in this order, with later ones overriding earlier:
      1. Dassl defaults (``get_cfg_default``) plus the ``extend_cfg`` additions;
      2. the dataset config file (``--dataset-config-file``), if given;
      3. the method config file (``--config-file``), if given;
      4. explicit CLI flags, via ``reset_cfg``;
      5. trailing ``KEY VALUE`` pairs collected in ``args.opts``.

    Returns:
        The frozen config object.
    """
    cfg = get_cfg_default()
    extend_cfg(cfg)

    # The dataset file is merged first so the method file can override any key both define.
    for yaml_path in (args.dataset_config_file, args.config_file):
        if yaml_path:
            cfg.merge_from_file(yaml_path)

    reset_cfg(cfg, args)
    cfg.merge_from_list(args.opts)

    cfg.freeze()
    return cfg
|
|
|
|
|
def main(args):
    """Entry point: build the config, set up the run, then train or evaluate.

    With ``--eval-only``, weights are loaded from ``args.model_dir`` (at
    ``args.load_epoch``) and ``trainer.test()`` runs. Otherwise
    ``trainer.train()`` runs, unless ``--no-train`` was passed.
    """
    cfg = setup_cfg(args)

    # Any non-negative seed enables deterministic seeding. This message is
    # printed before setup_logger() is called.
    seed = cfg.SEED
    if seed >= 0:
        print("Setting fixed seed: {}".format(seed))
        set_random_seed(seed)
    setup_logger(cfg.OUTPUT_DIR)

    cuda_in_use = torch.cuda.is_available() and cfg.USE_CUDA
    if cuda_in_use:
        # Let cuDNN autotune its convolution algorithms.
        torch.backends.cudnn.benchmark = True

    print_args(args, cfg)
    print("Collecting env info ...")
    print("** System info **\n{}\n".format(collect_env_info()))

    trainer = build_trainer(cfg)

    if not args.eval_only:
        if not args.no_train:
            trainer.train()
        return

    trainer.load_model(args.model_dir, epoch=args.load_epoch)
    trainer.test()
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("--root", type=str, default="", help="path to dataset") |
|
parser.add_argument("--output-dir", type=str, default="", help="output directory") |
|
parser.add_argument( |
|
"--resume", |
|
type=str, |
|
default="", |
|
help="checkpoint directory (from which the training resumes)", |
|
) |
|
parser.add_argument( |
|
"--seed", type=int, default=-1, help="only positive value enables a fixed seed" |
|
) |
|
parser.add_argument( |
|
"--source-domains", type=str, nargs="+", help="source domains for DA/DG" |
|
) |
|
parser.add_argument( |
|
"--target-domains", type=str, nargs="+", help="target domains for DA/DG" |
|
) |
|
parser.add_argument( |
|
"--transforms", type=str, nargs="+", help="data augmentation methods" |
|
) |
|
parser.add_argument( |
|
"--config-file", type=str, default="", help="path to config file" |
|
) |
|
parser.add_argument( |
|
"--dataset-config-file", |
|
type=str, |
|
default="", |
|
help="path to config file for dataset setup", |
|
) |
|
parser.add_argument("--trainer", type=str, default="", help="name of trainer") |
|
parser.add_argument("--backbone", type=str, default="", help="name of CNN backbone") |
|
parser.add_argument("--head", type=str, default="", help="name of head") |
|
parser.add_argument("--eval-only", action="store_true", help="evaluation only") |
|
parser.add_argument( |
|
"--model-dir", |
|
type=str, |
|
default="", |
|
help="load model from this directory for eval-only mode", |
|
) |
|
parser.add_argument( |
|
"--load-epoch", type=int, help="load model weights at this epoch for evaluation" |
|
) |
|
parser.add_argument( |
|
"--no-train", action="store_true", help="do not call trainer.train()" |
|
) |
|
parser.add_argument( |
|
"opts", |
|
default=None, |
|
nargs=argparse.REMAINDER, |
|
help="modify config options using the command-line", |
|
) |
|
args = parser.parse_args() |
|
main(args) |
|
|