File size: 5,194 Bytes
803ef9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
from functools import partial
import argparse
from torchvision import models
import multiprocessing
from datasets import DS_LIST
from methods import METHOD_LIST


def get_cfg(args=None) -> argparse.Namespace:
    """Build the run configuration from command-line style arguments.

    Args:
        args: Optional list of argument strings to parse. When ``None``
            (the default, and the only behaviour of earlier versions),
            argparse falls back to ``sys.argv[1:]``. Passing an explicit
            list lets the configuration be built programmatically, e.g.
            from tests or notebooks, without touching ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` holding every option defined below.

    Raises:
        SystemExit: raised by argparse on ``--help`` or on invalid
            arguments (unknown option, value outside ``choices``, bad type).
    """
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "--method", type=str, choices=METHOD_LIST, default="w_mse", help="loss type",
    )
    parser.add_argument(
        "--wandb",
        type=str,
        default="ssl-sota",
        help="name of the project for logging at https://wandb.ai",
    )
    parser.add_argument(
        "--byol_tau", type=float, default=0.99, help="starting tau for byol loss"
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=2,
        help="number of samples (d) generated from each image",
    )

    # Shorthand for the augmentation options below, which are all floats.
    addf = partial(parser.add_argument, type=float)
    addf("--cj0", default=0.4, help="color jitter brightness")
    addf("--cj1", default=0.4, help="color jitter contrast")
    addf("--cj2", default=0.4, help="color jitter saturation")
    addf("--cj3", default=0.1, help="color jitter hue")
    addf("--cj_p", default=0.8, help="color jitter probability")
    addf("--gs_p", default=0.1, help="grayscale probability")
    addf("--crop_s0", default=0.2, help="crop size from")
    addf("--crop_s1", default=1.0, help="crop size to")
    addf("--crop_r0", default=0.75, help="crop ratio from")
    addf("--crop_r1", default=(4 / 3), help="crop ratio to")
    addf("--hf_p", default=0.5, help="horizontal flip probability")

    # Inverted flags: cfg.lr_warmup / cfg.add_bn default to True and are
    # switched off by passing --no_lr_warmup / --no_add_bn.
    parser.add_argument(
        "--no_lr_warmup",
        dest="lr_warmup",
        action="store_false",
        help="do not use learning rate warmup",
    )
    parser.add_argument(
        "--no_add_bn", dest="add_bn", action="store_false", help="do not use BN in head"
    )
    parser.add_argument("--knn", type=int, default=5, help="k in k-nn classifier")
    parser.add_argument("--fname", type=str, help="load model from file")
    parser.add_argument(
        "--lr_step",
        type=str,
        choices=["cos", "step", "none"],
        default="step",
        help="learning rate schedule type",
    )
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument(
        "--eta_min", type=float, default=0, help="min learning rate (for --lr_step cos)"
    )
    parser.add_argument(
        "--adam_l2", type=float, default=1e-6, help="weight decay (L2 penalty)"
    )
    # No default: cfg.T0 is None unless given explicitly.
    parser.add_argument("--T0", type=int, help="period (for --lr_step cos)")
    parser.add_argument(
        "--Tmult", type=int, default=1, help="period factor (for --lr_step cos)"
    )
    parser.add_argument(
        "--w_eps", type=float, default=1e-4, help="eps for stability for whitening"
    )
    parser.add_argument(
        "--head_layers", type=int, default=2, help="number of FC layers in head"
    )
    parser.add_argument(
        "--head_size", type=int, default=1024, help="size of FC layers in head"
    )

    parser.add_argument(
        "--w_size", type=int, default=128, help="size of sub-batch for W-MSE loss"
    )
    parser.add_argument(
        "--w_iter",
        type=int,
        default=1,
        help="iterations for whitening matrix estimation",
    )

    # Inverted flag: cfg.norm defaults to True; --no_norm sets it to False.
    parser.add_argument(
        "--no_norm", dest="norm", action="store_false", help="don't normalize latents",
    )
    parser.add_argument(
        "--tau", type=float, default=0.5, help="contrastive loss temperature"
    )

    parser.add_argument("--epoch", type=int, default=200, help="total epoch number")
    parser.add_argument(
        "--eval_every_drop",
        type=int,
        default=5,
        help="how often to evaluate after learning rate drop",
    )
    parser.add_argument(
        "--eval_every", type=int, default=20, help="how often to evaluate"
    )
    parser.add_argument("--emb", type=int, default=64, help="embedding size")
    parser.add_argument(
        "--bs", type=int, default=384, help="number of original images in batch N",
    )
    # The default list literal is rebuilt on every call, so mutating
    # cfg.drop cannot leak into later configurations.
    parser.add_argument(
        "--drop",
        type=int,
        nargs="*",
        default=[50, 25],
        help="milestones for learning rate decay (0 = last epoch)",
    )
    parser.add_argument(
        "--drop_gamma",
        type=float,
        default=0.2,
        help="multiplicative factor of learning rate decay",
    )
    # Valid architectures are discovered at runtime: every public name in
    # torchvision.models whose name contains "resn".
    parser.add_argument(
        "--arch",
        type=str,
        choices=[x for x in dir(models) if "resn" in x],
        default="resnet18",
        help="encoder architecture",
    )
    parser.add_argument("--dataset", type=str, choices=DS_LIST, default="cifar10")
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="dataset workers number",
    )
    parser.add_argument(
        "--clf",
        type=str,
        default="sgd",
        choices=["sgd", "knn", "lbfgs"],
        help="classifier for test.py",
    )
    parser.add_argument(
        "--eval_head", action="store_true", help="eval head output instead of model",
    )
    # NOTE(review): "~" is not expanded here; presumably the consumer calls
    # os.path.expanduser — confirm at the call site.
    parser.add_argument("--imagenet_path", type=str, default="~/IN100/")
    # args=None keeps the original behaviour of reading sys.argv[1:].
    return parser.parse_args(args)