"""Default parameter dictionaries for training and evaluation.

``get_default_params_train`` and ``get_default_params_eval`` each build a
fresh dict of defaults and then layer caller-supplied overrides on top via
``apply_overrides``.
"""

import copy

import torch


def apply_overrides(params, overrides):
    """Return a deep copy of ``params`` with ``overrides`` applied on top.

    Args:
        params: Dict of default parameters. It is not modified.
        overrides: Mapping of parameter name to new value, or ``None`` for
            no overrides.

    Returns:
        A new dict: a deep copy of ``params`` updated with every entry of
        ``overrides``. Override values are inserted as-is (not copied).

    Notes:
        An override whose name is not already present in ``params`` is
        reported with a printed message, but it is still added to the
        result; raising on unknown names is intentionally left disabled.
    """
    params = copy.deepcopy(params)
    if overrides is None:
        return params
    for param_name in overrides:
        if param_name not in params:
            print(f'override failed: no parameter named {param_name}')
            # raise ValueError
        params[param_name] = overrides[param_name]
    return params


def get_default_params_train(overrides=None):
    """Build the default training parameters, then apply ``overrides``.

    Args:
        overrides: Optional mapping of parameter name to value that replaces
            (or adds to) the defaults below. ``None`` means no overrides.
            Using ``None`` instead of a ``{}`` default avoids a shared
            mutable default argument.

    Returns:
        A new dict of training parameters.
    """
    params = {}

    # --- misc ---
    params['device'] = 'cuda'  # cuda, cpu
    params['save_base'] = './experiments/'
    params['save_frequency'] = 5
    params['experiment_name'] = 'demo'
    params['timestamp'] = False

    # --- data ---
    params['species_set'] = 'all'  # all, snt_birds
    params['hard_cap_seed'] = 9472
    params['hard_cap_num_per_class'] = -1  # -1 for no hard capping
    params['aux_species_seed'] = 8099
    params['num_aux_species'] = 0  # for snt_birds case, how many other species to add in
    params['input_time'] = False  # whether to input time as a feature
    params['input_time_dim'] = 0
    params['dataset'] = 'inat'  # inat, iucn_inat, iucn_uniform
    params['zero_shot'] = False
    params['subset_cap_name'] = None
    params['subset_cap_num_per_class'] = -1

    # MINE - I added these - check if there is any impact
    params['seed'] = 1000
    params['add_location_noise'] = False
    params['variable_context_length'] = False
    params['eval_dataset'] = 'eval_transformer'
    params['eval_num_context'] = 20
    params['use_text_inputs'] = True
    params['use_image_inputs'] = False
    params['use_env_inputs'] = False
    params['class_token_transformation'] = 'identity'
    params['loc_prob'] = 1.0
    params['text_prob'] = 0.0
    params['image_prob'] = 0.0
    params['env_prob'] = 0.0

    # --- data files ---
    params['obs_file'] = 'geo_prior_train.csv'
    params['taxa_file'] = 'geo_prior_train_meta.json'

    # --- model ---
    params['model'] = 'ResidualFCNet'  # ResidualFCNet, LinNet
    params['num_filts'] = 256  # embedding dimension
    params['input_enc'] = 'sin_cos'  # sin_cos, env, sin_cos_env
    params['input_dim'] = 4
    params['depth'] = 4
    params['noise_time'] = False
    params['species_dim'] = 0
    params['species_enc_depth'] = 0
    params['species_filts'] = 256
    params['species_enc'] = 'embed'
    params['text_emb_path'] = ''
    params['image_emb_path'] = ''
    params['text_learn_dim'] = 0
    params['text_hidden_dim'] = 0
    params['text_num_layers'] = 1
    params['text_batchnorm'] = False
    params['species_dropout'] = 0.0
    params['geoprior_temp'] = 0.0

    # MINE - I added these
    params['num_context'] = 50
    params['transformer_input_enc'] = 'sin_cos'
    params['transformer_dropout'] = 0.1
    params['num_heads'] = 8
    params['ema_factor'] = 0.1
    params['use_register'] = True
    params['use_pretrained_sinr'] = False
    params['pretrained_loc'] = ''
    params['freeze_sinr'] = False
    params['pos_enc_class'] = 'sinr'

    # --- loss ---
    params['loss'] = 'an_full'  # an_full, an_ssdl, an_slds
    params['pos_weight'] = 2048

    # --- optimization ---
    params['batch_size'] = 2048
    params['lr'] = 0.0005
    params['lr_decay'] = 0.98
    params['num_epochs'] = 10

    # --- saving ---
    params['log_frequency'] = 512

    return apply_overrides(params, overrides)


def get_default_params_eval(overrides=None):
    """Build the default evaluation parameters, then apply ``overrides``.

    Args:
        overrides: Optional mapping of parameter name to value that replaces
            (or adds to) the defaults below. ``None`` means no overrides.

    Returns:
        A new dict of evaluation parameters. Unlike the training defaults,
        ``'device'`` here is a ``torch.device``, chosen as CUDA when
        ``torch.cuda.is_available()`` and CPU otherwise.
    """
    params = {}

    # --- misc ---
    params['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    params['seed'] = 2022
    params['exp_base'] = './experiments'
    params['ckp_name'] = 'model.pt'
    params['eval_type'] = 'snt'  # snt, iucn, geo_prior, geo_feature
    params['experiment_name'] = 'demo'
    params['input_dim'] = 4
    params['input_time'] = False
    params['input_time_dim'] = 0

    # mine
    # params['num_samples'] = -1  # maxs
    params['num_samples'] = 0
    params['text_section'] = ''
    params['extract_pos'] = False
    # MINE - but probably not needed anymore
    # params['target_background'] = True

    # --- geo prior ---
    params['batch_size'] = 2048

    # --- geo feature ---
    params['cell_size'] = 25

    return apply_overrides(params, overrides)