File size: 13,342 Bytes
8c212a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
import sys
import argparse
import os.path as osp
import json
import torch
import clip
from lib import *
from lib import GENFORCE_MODELS, STYLEGAN_LAYERS, SEMANTIC_DIPOLES_CORPORA
from models.load_generator import load_generator


def main():
    """ContraCLIP -- Training script.

    Parse the command-line options listed below, build the pre-trained GAN generator (G), the pre-trained CLIP model,
    the Corpus Support Sets (CSS) and the Latent Support Sets (LSS), and hand them to a `Trainer` for training.

    Options:
        ===[ Experiment ]===============================================================================================
        --exp-id                     : set optional experiment ID

        ===[ GAN Generator (G) ]========================================================================================
        --gan                        : set pre-trained GAN generator (required; see GENFORCE_MODELS in lib/config.py)
        --stylegan-space             : set StyleGAN's latent space (Z, W, W+) to look for interpretable paths
                                       TODO: add style space S
        --stylegan-layer             : choose up to which StyleGAN's layer to use for learning latent paths
                                       E.g., if --stylegan-layer=11, then interpretable paths will be learnt in a
                                       (12 * 512)-dimensional latent space.
        --truncation                 : set W-space truncation parameter. If set, W-space codes will be truncated.
                                       Required for StyleGAN generators, as it is used to compute the Jung radius

        ===[ Corpus Support Sets (CSS) ]================================================================================
        --corpus                     : choose corpus of semantic dipoles (required; see SEMANTIC_DIPOLES_CORPORA in
                                       lib). The number of elements of SEMANTIC_DIPOLES_CORPORA[args.corpus] will
                                       define the number of the latent support sets; i.e., the number of warping
                                       functions -- number of the interpretable latent paths to be optimised
                                       TODO: read corpus from input file
        --css-beta                   : set beta parameter for initializing CLIP space RBFs' gamma parameters
                                       (0.25 <= css_beta < 1.0)
        --styleclip                  : use StyleCLIP approach for calculating image-text similarity
        --linear                     : use the vector connecting the poles of the dipole for calculating image-text
                                       similarity

        ===[ Latent Support Sets (LSS) ]================================================================================
        --num-latent-support-dipoles : set number of support dipoles per support set
        --lss-beta                   : set beta parameter for initializing latent space RBFs' gamma parameters
                                       (0.0 < lss_beta < 1.0)
        --lr                         : set learning rate for learning the latent support sets LSS (with Adam optimizer)
        --min-shift-magnitude        : set minimum latent shift magnitude
        --max-shift-magnitude        : set maximum latent shift magnitude

        ===[ CLIP ]=====================================================================================================
        (no options -- the pre-trained CLIP model is fixed to ViT-B/32)

        ===[ Training ]=================================================================================================
        --max-iter                   : set maximum number of training iterations
        --batch-size                 : set training batch size (required; clipped to the size of the given corpus)
        --loss                       : set loss function ('cossim', 'contrastive')
        --temperature                : set contrastive loss temperature
        --log-freq                   : set number iterations per log
        --ckp-freq                   : set number iterations per checkpoint model saving

        ===[ CUDA ]=====================================================================================================
        --cuda                       : use CUDA during training (default)
        --no-cuda                    : do NOT use CUDA during training
        ================================================================================================================

    Raises:
        ValueError: if --stylegan-layer is out of range for the chosen StyleGAN generator, or if --truncation is not
                    given for a StyleGAN generator.
    """
    parser = argparse.ArgumentParser(description="ContraCLIP training script")

    # === Experiment ID ============================================================================================== #
    parser.add_argument('--exp-id', type=str, default='', help="set optional experiment ID")

    # === Pre-trained GAN Generator (G) ============================================================================== #
    # `required=True`: without a generator, every later lookup (e.g. `'stylegan' in args.gan`) would fail on None
    parser.add_argument('--gan', type=str, required=True, choices=GENFORCE_MODELS.keys(), help='GAN generator model')
    parser.add_argument('--stylegan-space', type=str, default='Z', choices=('Z', 'W', 'W+'),
                        help="StyleGAN's latent space")
    parser.add_argument('--stylegan-layer', type=int, default=11, choices=range(18),
                        help="choose up to which StyleGAN's layer to use for learning latent paths")
    parser.add_argument('--truncation', type=float, help="latent code sampling truncation parameter")

    # === Corpus Support Sets (CSS) ================================================================================== #
    parser.add_argument('--corpus', type=str, required=True, choices=SEMANTIC_DIPOLES_CORPORA.keys(),
                        help="choose corpus of semantic dipoles")
    parser.add_argument('--css-beta', type=float, default=0.5,
                        help="set beta parameter for initializing CLIP space RBFs' gamma parameters "
                             "(0.25 <= css_beta < 1.0)")
    parser.add_argument('--styleclip', action='store_true',
                        help="use StyleCLIP approach for calculating image-text similarity")
    parser.add_argument('--linear', action='store_true',
                        help="use the vector connecting the poles of the dipole for calculating image-text similarity")

    # === Latent Support Sets (LSS) ================================================================================== #
    # NOTE(review): no default -- if omitted, None is passed to SupportSets(num_support_dipoles=...); confirm that
    # SupportSets handles this, otherwise this option should be required
    parser.add_argument('--num-latent-support-dipoles', type=int, help="number of latent support dipoles / support set")
    parser.add_argument('--lss-beta', type=float, default=0.1,
                        help="set beta parameter for initializing latent space RBFs' gamma parameters "
                             "(0.0 < lss_beta < 1.0)")
    parser.add_argument('--lr', type=float, default=1e-4, help="latent support sets LSS learning rate")
    parser.add_argument('--min-shift-magnitude', type=float, default=0.25, help="minimum latent shift magnitude")
    parser.add_argument('--max-shift-magnitude', type=float, default=0.45, help="maximum latent shift magnitude")

    # === Training =================================================================================================== #
    parser.add_argument('--max-iter', type=int, default=10000, help="maximum number of training iterations")
    parser.add_argument('--batch-size', type=int, required=True, help="training batch size -- this should be less than "
                                                                      "or equal to the size of the given corpus")
    parser.add_argument('--loss', type=str, default='cossim', choices=('cossim', 'contrastive'),
                        help="loss function")
    parser.add_argument('--temperature', type=float, default=1.0, help="contrastive temperature")
    parser.add_argument('--log-freq', default=10, type=int, help='number of iterations per log')
    parser.add_argument('--ckp-freq', default=1000, type=int, help='number of iterations per checkpoint model saving')

    # === CUDA ======================================================================================================= #
    parser.add_argument('--cuda', dest='cuda', action='store_true', help="use CUDA during training")
    parser.add_argument('--no-cuda', dest='cuda', action='store_false', help="do NOT use CUDA during training")
    parser.set_defaults(cuda=True)
    # ================================================================================================================ #

    # Parse given arguments
    args = parser.parse_args()

    # Check given batch size -- it cannot exceed the number of elements of the chosen corpus
    corpus_size = len(SEMANTIC_DIPOLES_CORPORA[args.corpus])
    if args.batch_size > corpus_size:
        print("*** WARNING ***: Given batch size ({}) is greater than the size of the given corpus ({})\n"
              "                 Set batch size to {}".format(args.batch_size, corpus_size, corpus_size))
        args.batch_size = corpus_size

    is_stylegan = 'stylegan' in args.gan

    # Validate StyleGAN-specific arguments before creating the experiment dir or loading any model
    if is_stylegan:
        # Check StyleGAN's layer
        if (args.stylegan_layer < 0) or (args.stylegan_layer > STYLEGAN_LAYERS[args.gan]-1):
            raise ValueError("Invalid stylegan_layer for given GAN ({}). Choose between 0 and {}".format(
                args.gan, STYLEGAN_LAYERS[args.gan]-1))
        # The Jung radius of StyleGAN generators is computed from the truncation parameter (see below); without
        # this check, an omitted --truncation would only fail later with a TypeError on `None * float`
        if args.truncation is None:
            raise ValueError("A truncation parameter (--truncation) is required for StyleGAN generators "
                             "(given GAN: {})".format(args.gan))

    # Create output dir and save current arguments
    exp_dir = create_exp_dir(args)

    # CUDA: use it only if it is both available and requested; enable multi-GPU when more than one device is visible
    use_cuda = False
    multi_gpu = False
    if torch.cuda.is_available():
        if args.cuda:
            use_cuda = True
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
            if torch.cuda.device_count() > 1:
                multi_gpu = True
        else:
            print("*** WARNING ***: It looks like you have a CUDA device, but aren't using CUDA.\n"
                  "                 Run with --cuda for optimal training speed.")
            torch.set_default_tensor_type('torch.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')

    # Build GAN generator model and load with pre-trained weights
    print("#. Build GAN generator model G and load with pre-trained weights...")
    print("  \\__GAN generator : {} (res: {})".format(args.gan, GENFORCE_MODELS[args.gan][1]))
    print("  \\__Pre-trained weights: {}".format(GENFORCE_MODELS[args.gan][0]))
    # Both 'W' and 'W+' spaces contain 'W', so the generator takes W-space codes in either case
    G = load_generator(model_name=args.gan,
                       latent_is_w=is_stylegan and ('W' in args.stylegan_space),
                       verbose=True).eval()

    # Upload GAN generator model to GPU
    if use_cuda:
        G = G.cuda()

    # Build pretrained CLIP model (kept in float32 and in eval mode)
    print("#. Build pretrained CLIP model...")
    clip_model, _ = clip.load("ViT-B/32", device='cuda' if use_cuda else 'cpu', jit=False)
    clip_model.float()
    clip_model.eval()

    # Get CLIP (non-normalized) text features for the prompts of the given corpus
    prompt_f = PromptFeatures(prompt_corpus=SEMANTIC_DIPOLES_CORPORA[args.corpus], clip_model=clip_model)
    prompt_features = prompt_f.get_prompt_features()

    # Build Corpus Support Sets model CSS
    print("#. Build Corpus Support Sets CSS...")
    print("  \\__Number of corpus support sets    : {}".format(prompt_f.num_prompts))
    print("  \\__Number of corpus support dipoles : {}".format(1))
    print("  \\__Prompt features dim              : {}".format(prompt_f.prompt_features_dim))
    print("  \\__Text RBF beta param              : {}".format(args.css_beta))

    CSS = SupportSets(prompt_features=prompt_features, css_beta=args.css_beta)

    # Count number of trainable parameters
    CSS_trainable_parameters = sum(p.numel() for p in CSS.parameters() if p.requires_grad)
    print("  \\__Trainable parameters: {:,}".format(CSS_trainable_parameters))

    # Set support vector dimensionality: in W+ space, one G.dim_z-sized code per layer 0..stylegan_layer is learnt
    support_vectors_dim = G.dim_z
    if is_stylegan and (args.stylegan_space == 'W+'):
        support_vectors_dim *= (args.stylegan_layer + 1)

    # Get Jung radii
    with open(osp.join('models', 'jung_radii.json'), 'r') as f:
        jung_radii_dict = json.load(f)

    # Each selected entry `lm` is used as a pair evaluated as lm[0] * truncation + lm[1] (presumably the slope and
    # intercept of a linear fit in the truncation parameter -- TODO confirm against models/jung_radii.json)
    if is_stylegan:
        if 'W+' in args.stylegan_space:
            lm = jung_radii_dict[args.gan]['W']['{}'.format(args.stylegan_layer)]
        elif 'W' in args.stylegan_space:
            lm = jung_radii_dict[args.gan]['W']['0']
        else:
            lm = jung_radii_dict[args.gan]['Z']
        jung_radius = lm[0] * args.truncation + lm[1]
    else:
        # Non-StyleGAN generators: only the second element of the 'Z' entry is used
        jung_radius = jung_radii_dict[args.gan]['Z'][1]

    # Build Latent Support Sets model LSS
    print("#. Build Latent Support Sets LSS...")
    print("  \\__Number of latent support sets    : {}".format(prompt_f.num_prompts))
    print("  \\__Number of latent support dipoles : {}".format(args.num_latent_support_dipoles))
    print("  \\__Support Vectors dim              : {}".format(support_vectors_dim))
    print("  \\__Latent RBF beta param (lss-beta) : {}".format(args.lss_beta))
    print("  \\__Jung radius                      : {}".format(jung_radius))

    LSS = SupportSets(num_support_sets=prompt_f.num_prompts,
                      num_support_dipoles=args.num_latent_support_dipoles,
                      support_vectors_dim=support_vectors_dim,
                      lss_beta=args.lss_beta,
                      jung_radius=jung_radius)

    # Count number of trainable parameters
    LSS_trainable_parameters = sum(p.numel() for p in LSS.parameters() if p.requires_grad)
    print("  \\__Trainable parameters: {:,}".format(LSS_trainable_parameters))

    # Set up trainer
    print("#. Experiment: {}".format(exp_dir))
    t = Trainer(params=args, exp_dir=exp_dir, use_cuda=use_cuda, multi_gpu=multi_gpu)

    # Train
    t.train(generator=G, latent_support_sets=LSS, corpus_support_sets=CSS, clip_model=clip_model)


if __name__ == "__main__":
    # Script entry point: run ContraCLIP training only when executed directly (not on import)
    main()