File size: 5,194 Bytes
803ef9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
from functools import partial
import argparse
from torchvision import models
import multiprocessing
from datasets import DS_LIST
from methods import METHOD_LIST
def get_cfg(argv=None):
    """Generate the experiment configuration from command-line arguments.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace with one attribute per option defined below
        (attribute names follow each option's ``dest``; for example
        ``--no_lr_warmup`` sets ``lr_warmup``).
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "--method", type=str, choices=METHOD_LIST, default="w_mse", help="loss type",
    )
    parser.add_argument(
        "--wandb",
        type=str,
        default="ssl-sota",
        help="name of the project for logging at https://wandb.ai",
    )
    parser.add_argument(
        "--byol_tau", type=float, default=0.99, help="starting tau for byol loss"
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=2,
        help="number of samples (d) generated from each image",
    )

    # Shorthand for the block of float-typed augmentation options below.
    addf = partial(parser.add_argument, type=float)
    addf("--cj0", default=0.4, help="color jitter brightness")
    addf("--cj1", default=0.4, help="color jitter contrast")
    addf("--cj2", default=0.4, help="color jitter saturation")
    addf("--cj3", default=0.1, help="color jitter hue")
    addf("--cj_p", default=0.8, help="color jitter probability")
    addf("--gs_p", default=0.1, help="grayscale probability")
    addf("--crop_s0", default=0.2, help="crop size from")
    addf("--crop_s1", default=1.0, help="crop size to")
    addf("--crop_r0", default=0.75, help="crop ratio from")
    addf("--crop_r1", default=(4 / 3), help="crop ratio to")
    addf("--hf_p", default=0.5, help="horizontal flip probability")

    # store_false flags: the attribute (dest) defaults to True and the
    # "--no_*" switch turns the feature off.
    parser.add_argument(
        "--no_lr_warmup",
        dest="lr_warmup",
        action="store_false",
        help="do not use learning rate warmup",
    )
    parser.add_argument(
        "--no_add_bn", dest="add_bn", action="store_false", help="do not use BN in head"
    )
    parser.add_argument("--knn", type=int, default=5, help="k in k-nn classifier")
    parser.add_argument("--fname", type=str, help="load model from file")
    parser.add_argument(
        "--lr_step",
        type=str,
        choices=["cos", "step", "none"],
        default="step",
        help="learning rate schedule type",
    )
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    # Float default so the attribute type matches ``type=float``; argparse
    # does not convert non-string defaults.
    parser.add_argument(
        "--eta_min", type=float, default=0.0, help="min learning rate (for --lr_step cos)"
    )
    parser.add_argument(
        "--adam_l2", type=float, default=1e-6, help="weight decay (L2 penalty)"
    )
    # No default: cfg.T0 is None unless given on the command line.
    parser.add_argument("--T0", type=int, help="period (for --lr_step cos)")
    parser.add_argument(
        "--Tmult", type=int, default=1, help="period factor (for --lr_step cos)"
    )
    parser.add_argument(
        "--w_eps", type=float, default=1e-4, help="eps for stability for whitening"
    )
    parser.add_argument(
        "--head_layers", type=int, default=2, help="number of FC layers in head"
    )
    parser.add_argument(
        "--head_size", type=int, default=1024, help="size of FC layers in head"
    )
    parser.add_argument(
        "--w_size", type=int, default=128, help="size of sub-batch for W-MSE loss"
    )
    parser.add_argument(
        "--w_iter",
        type=int,
        default=1,
        help="iterations for whitening matrix estimation",
    )
    parser.add_argument(
        "--no_norm", dest="norm", action="store_false", help="don't normalize latents",
    )
    parser.add_argument(
        "--tau", type=float, default=0.5, help="contrastive loss temperature"
    )
    parser.add_argument("--epoch", type=int, default=200, help="total epoch number")
    parser.add_argument(
        "--eval_every_drop",
        type=int,
        default=5,
        help="how often to evaluate after learning rate drop",
    )
    parser.add_argument(
        "--eval_every", type=int, default=20, help="how often to evaluate"
    )
    parser.add_argument("--emb", type=int, default=64, help="embedding size")
    parser.add_argument(
        "--bs", type=int, default=384, help="number of original images in batch N",
    )
    # nargs="*" accepts zero or more ints; a fresh list literal is built on
    # every call, so the default is not shared between parsers.
    parser.add_argument(
        "--drop",
        type=int,
        nargs="*",
        default=[50, 25],
        help="milestones for learning rate decay (0 = last epoch)",
    )
    parser.add_argument(
        "--drop_gamma",
        type=float,
        default=0.2,
        help="multiplicative factor of learning rate decay",
    )

    # Offer the ResNet-family names exported by torchvision.models. The
    # callable() filter drops the ``resnet`` submodule itself: it also
    # contains "resn" but cannot be called to build an encoder.
    arch_choices = [
        name
        for name in dir(models)
        if "resn" in name and callable(getattr(models, name))
    ]
    parser.add_argument(
        "--arch",
        type=str,
        choices=arch_choices,
        default="resnet18",
        help="encoder architecture",
    )
    parser.add_argument("--dataset", type=str, choices=DS_LIST, default="cifar10")
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="dataset workers number",
    )
    parser.add_argument(
        "--clf",
        type=str,
        default="sgd",
        choices=["sgd", "knn", "lbfgs"],
        help="classifier for test.py",
    )
    parser.add_argument(
        "--eval_head", action="store_true", help="eval head output instead of model",
    )
    # NOTE(review): "~" is not expanded here; presumably the dataset loader
    # expands it — confirm against the consumer of cfg.imagenet_path.
    parser.add_argument("--imagenet_path", type=str, default="~/IN100/")
    return parser.parse_args(argv)
|