import sys
import argparse
import os.path as osp
import json
import torch
import clip
from lib import *
from lib import GENFORCE_MODELS, STYLEGAN_LAYERS, SEMANTIC_DIPOLES_CORPORA
from models.load_generator import load_generator


def main():
    """ContraCLIP -- Training script.

    Options:
        ===[ GAN Generator (G) ]========================================================================================
        --gan            : set pre-trained GAN generator (see GENFORCE_MODELS in lib/config.py)
        --stylegan-space : set StyleGAN's latent space (Z, W, W+) to look for interpretable paths
                           TODO: add style space S
        --stylegan-layer : choose up to which StyleGAN's layer to use for learning latent paths
                           E.g., if --stylegan-layer=11, then interpretable paths will be learnt in a
                           (12 * 512)-dimensional latent space.
        --truncation     : set W-space truncation parameter. If set, W-space codes will be truncated

        ===[ Corpus Support Sets (CSS) ]================================================================================
        --corpus    : choose corpus of prompts (see SEMANTIC_DIPOLES_CORPORA in lib/config.py). The number of elements
                      of the tuple SEMANTIC_DIPOLES_CORPORA[args.corpus] defines the number of latent support sets;
                      i.e., the number of warping functions -- the number of interpretable latent paths to be optimised
                      TODO: read corpus from input file
        --css-beta  : set beta parameter for fixing CLIP space RBFs' gamma parameters (0.25 <= css_beta < 1.0)
        --styleclip : use StyleCLIP approach for calculating image-text similarity

        ===[ Latent Support Sets (LSS) ]================================================================================
        --num-latent-support-dipoles : set number of support dipoles per support set
        --lss-beta                   : set beta parameter for initializing latent space RBFs' gamma parameters
                                       (0.0 < lss_beta < 1.0)
        --lr                         : set learning rate for learning the latent support sets LSS (with Adam optimizer)
        --linear                     : use the vector connecting the poles of the dipole for calculating image-text
                                       similarity
        --min-shift-magnitude        : set minimum latent shift magnitude
        --max-shift-magnitude        : set maximum latent shift magnitude

        ===[ CLIP ]=====================================================================================================

        ===[ Training ]=================================================================================================
        --max-iter    : set maximum number of training iterations
        --batch-size  : set training batch size
        --loss        : set loss function ('cossim', 'contrastive')
        --temperature : set contrastive loss temperature
        --log-freq    : set number of iterations per log
        --ckp-freq    : set number of iterations per checkpoint model saving

        ===[ CUDA ]=====================================================================================================
        --cuda    : use CUDA during training (default)
        --no-cuda : do NOT use CUDA during training
        ================================================================================================================
    """
    parser = argparse.ArgumentParser(description="ContraCLIP training script")

    # === Experiment ID ============================================================================================== #
    parser.add_argument('--exp-id', type=str, default='', help="set optional experiment ID")

    # === Pre-trained GAN Generator (G) ============================================================================== #
    parser.add_argument('--gan', type=str, choices=GENFORCE_MODELS.keys(), help='GAN generator model')
    parser.add_argument('--stylegan-space', type=str, default='Z', choices=('Z', 'W', 'W+'),
                        help="StyleGAN's latent space")
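    # Note: with --stylegan-space=W+, a latent code is a stack of per-layer w codes; combined with
    # --stylegan-layer=k, latent paths are then learnt in the concatenation of the first (k + 1) of
    # these codes, i.e., in a ((k + 1) * 512)-dimensional space (cf. the computation of
    # support_vectors_dim below and the docstring above).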
    parser.add_argument('--stylegan-layer', type=int, default=11, choices=range(18),
                        help="choose up to which StyleGAN's layer to use for learning latent paths")
    parser.add_argument('--truncation', type=float, help="latent code sampling truncation parameter")

    # === Corpus Support Sets (CSS) ================================================================================== #
    parser.add_argument('--corpus', type=str, required=True, choices=SEMANTIC_DIPOLES_CORPORA.keys(),
                        help="choose corpus of semantic dipoles")
    parser.add_argument('--css-beta', type=float, default=0.5,
                        help="set beta parameter for initializing CLIP space RBFs' gamma parameters "
                             "(0.25 <= css_beta < 1.0)")
    parser.add_argument('--styleclip', action='store_true',
                        help="use StyleCLIP approach for calculating image-text similarity")
    parser.add_argument('--linear', action='store_true',
                        help="use the vector connecting the poles of the dipole for calculating image-text similarity")

    # === Latent Support Sets (LSS) ================================================================================== #
    parser.add_argument('--num-latent-support-dipoles', type=int,
                        help="number of latent support dipoles per support set")
    parser.add_argument('--lss-beta', type=float, default=0.1,
                        help="set beta parameter for initializing latent space RBFs' gamma parameters "
                             "(0.0 < lss_beta < 1.0)")
    parser.add_argument('--lr', type=float, default=1e-4, help="latent support sets LSS learning rate")
    parser.add_argument('--min-shift-magnitude', type=float, default=0.25, help="minimum latent shift magnitude")
    parser.add_argument('--max-shift-magnitude', type=float, default=0.45, help="maximum latent shift magnitude")

    # === Training =================================================================================================== #
    parser.add_argument('--max-iter', type=int, default=10000, help="maximum number of training iterations")
    parser.add_argument('--batch-size', type=int, required=True,
                        help="training batch size -- this should be less than or equal to the size of the given "
                             "corpus")
    parser.add_argument('--loss', type=str, default='cossim', choices=('cossim', 'contrastive'), help="loss function")
    parser.add_argument('--temperature', type=float, default=1.0, help="contrastive temperature")
    parser.add_argument('--log-freq', default=10, type=int, help='number of iterations per log')
    parser.add_argument('--ckp-freq', default=1000, type=int, help='number of iterations per checkpoint model saving')

    # === CUDA ======================================================================================================= #
    parser.add_argument('--cuda', dest='cuda', action='store_true', help="use CUDA during training")
    parser.add_argument('--no-cuda', dest='cuda', action='store_false', help="do NOT use CUDA during training")
    parser.set_defaults(cuda=True)
    # ================================================================================================================ #

    # Parse given arguments
    args = parser.parse_args()

    # Check given batch size
    if args.batch_size > len(SEMANTIC_DIPOLES_CORPORA[args.corpus]):
        print("*** WARNING ***: Given batch size ({}) is greater than the size of the given corpus ({})\n"
              "                 Set batch size to {}".format(
                  args.batch_size,
                  len(SEMANTIC_DIPOLES_CORPORA[args.corpus]),
                  len(SEMANTIC_DIPOLES_CORPORA[args.corpus])))
        args.batch_size = len(SEMANTIC_DIPOLES_CORPORA[args.corpus])
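    # For reference, each corpus in SEMANTIC_DIPOLES_CORPORA (lib/config.py) is a tuple of semantic
    # dipoles, i.e., pairs of contrasting prompts, with one interpretable latent path learnt per
    # dipole -- hence the batch size cap above. A purely hypothetical entry (not an actual corpus of
    # the repository) could look like:
    #
    #   SEMANTIC_DIPOLES_CORPORA = {
    #       'example': (["a photo of a face with a beard.",
    #                    "a photo of a shaved face."],
    #                   ["a photo of an old face.",
    #                    "a photo of a young face."]),
    #   }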
    # Check StyleGAN's layer
    if 'stylegan' in args.gan:
        if (args.stylegan_layer < 0) or (args.stylegan_layer > STYLEGAN_LAYERS[args.gan] - 1):
            raise ValueError("Invalid stylegan_layer for given GAN ({}). Choose between 0 and {}".format(
                args.gan, STYLEGAN_LAYERS[args.gan] - 1))
        # The truncation parameter is needed below for computing the Jung radius of StyleGAN's latent space
        if args.truncation is None:
            raise ValueError("Latent space truncation parameter is required for StyleGAN generators "
                             "-- please set --truncation")

    # Create output dir and save current arguments
    exp_dir = create_exp_dir(args)

    # CUDA
    use_cuda = False
    multi_gpu = False
    if torch.cuda.is_available():
        if args.cuda:
            use_cuda = True
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
            if torch.cuda.device_count() > 1:
                multi_gpu = True
        else:
            print("*** WARNING ***: It looks like you have a CUDA device, but aren't using CUDA.\n"
                  "                 Run with --cuda for optimal training speed.")
            torch.set_default_tensor_type('torch.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')

    # Build GAN generator model and load with pre-trained weights
    print("#. Build GAN generator model G and load with pre-trained weights...")
    print("  \\__GAN generator      : {} (res: {})".format(args.gan, GENFORCE_MODELS[args.gan][1]))
    print("  \\__Pre-trained weights: {}".format(GENFORCE_MODELS[args.gan][0]))
    G = load_generator(model_name=args.gan,
                       latent_is_w=('stylegan' in args.gan) and ('W' in args.stylegan_space),
                       verbose=True).eval()

    # Upload GAN generator model to GPU
    if use_cuda:
        G = G.cuda()

    # Build pretrained CLIP model
    print("#. Build pretrained CLIP model...")
    clip_model, _ = clip.load("ViT-B/32", device='cuda' if use_cuda else 'cpu', jit=False)
    clip_model.float()
    clip_model.eval()

    # Get CLIP (non-normalized) text features for the prompts of the given corpus
    prompt_f = PromptFeatures(prompt_corpus=SEMANTIC_DIPOLES_CORPORA[args.corpus], clip_model=clip_model)
    prompt_features = prompt_f.get_prompt_features()

    # Build Corpus Support Sets model CSS
    print("#. Build Corpus Support Sets CSS...")
    print("  \\__Number of corpus support sets    : {}".format(prompt_f.num_prompts))
    print("  \\__Number of corpus support dipoles : {}".format(1))
    print("  \\__Prompt features dim              : {}".format(prompt_f.prompt_features_dim))
    print("  \\__Text RBF beta param              : {}".format(args.css_beta))
    CSS = SupportSets(prompt_features=prompt_features, css_beta=args.css_beta)

    # Count number of trainable parameters
    CSS_trainable_parameters = sum(p.numel() for p in CSS.parameters() if p.requires_grad)
    print("  \\__Trainable parameters: {:,}".format(CSS_trainable_parameters))

    # Set support vector dimensionality
    support_vectors_dim = G.dim_z
    if ('stylegan' in args.gan) and (args.stylegan_space == 'W+'):
        support_vectors_dim *= (args.stylegan_layer + 1)

    # Get Jung radii
    with open(osp.join('models', 'jung_radii.json'), 'r') as f:
        jung_radii_dict = json.load(f)
    if 'stylegan' in args.gan:
        if 'W+' in args.stylegan_space:
            lm = jung_radii_dict[args.gan]['W']['{}'.format(args.stylegan_layer)]
        elif 'W' in args.stylegan_space:
            lm = jung_radii_dict[args.gan]['W']['0']
        else:
            lm = jung_radii_dict[args.gan]['Z']
        jung_radius = lm[0] * args.truncation + lm[1]
    else:
        jung_radius = jung_radii_dict[args.gan]['Z'][1]
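    # For reference, models/jung_radii.json appears to store, for each GAN (and, for StyleGAN's
    # W/W+ space, for each layer), the (slope, intercept) coefficients of a linear model mapping the
    # truncation parameter to the Jung radius of the latent space, so that above
    # jung_radius = slope * truncation + intercept (only the intercept is used for non-StyleGAN
    # generators). The radius is handed to the LSS below, where, together with --lss-beta, it
    # presumably scales the initialisation of the latent RBFs' gamma parameters.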
    # Build Latent Support Sets model LSS
    print("#. Build Latent Support Sets LSS...")
    print("  \\__Number of latent support sets    : {}".format(prompt_f.num_prompts))
    print("  \\__Number of latent support dipoles : {}".format(args.num_latent_support_dipoles))
    print("  \\__Support Vectors dim              : {}".format(support_vectors_dim))
    print("  \\__Latent RBF beta param (lss-beta) : {}".format(args.lss_beta))
    print("  \\__Jung radius                      : {}".format(jung_radius))
    LSS = SupportSets(num_support_sets=prompt_f.num_prompts,
                      num_support_dipoles=args.num_latent_support_dipoles,
                      support_vectors_dim=support_vectors_dim,
                      lss_beta=args.lss_beta,
                      jung_radius=jung_radius)

    # Count number of trainable parameters
    LSS_trainable_parameters = sum(p.numel() for p in LSS.parameters() if p.requires_grad)
    print("  \\__Trainable parameters: {:,}".format(LSS_trainable_parameters))

    # Set up trainer
    print("#. Experiment: {}".format(exp_dir))
    t = Trainer(params=args, exp_dir=exp_dir, use_cuda=use_cuda, multi_gpu=multi_gpu)

    # Train
    t.train(generator=G, latent_support_sets=LSS, corpus_support_sets=CSS, clip_model=clip_model)


if __name__ == '__main__':
    main()
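
# Example invocation (hypothetical argument values -- the valid --gan and --corpus choices are the
# keys of GENFORCE_MODELS and SEMANTIC_DIPOLES_CORPORA in lib/config.py):
#
#   python train.py --gan=stylegan2_ffhq1024 --truncation=0.7 --stylegan-space=W+ \
#       --stylegan-layer=11 --corpus=expressions --num-latent-support-dipoles=32 \
#       --loss=contrastive --temperature=0.5 --batch-size=3 --max-iter=10000 --ckp-freq=1000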