Spaces:
Running
on
Zero
Running
on
Zero
############################################################################################
# Name: test_parameter.py
#
# NOTE: Change all your hyper-params here!
# Simple How-To Guide:
# 1. CLIP TTA: USE_CLIP_PREDS = True, EXECUTE_TTA = True
# 2. CLIP (No TTA): USE_CLIP_PREDS = True, EXECUTE_TTA = False
# 3. Custom masks (e.g. LLMSeg): USE_CLIP_PREDS = False, EXECUTE_TTA = False
#
# Settings assigned through getenv() below can also be overridden from the environment
# (e.g. `EXECUTE_TTA=false NUM_EPS_STEPS=256 python ...`); every value resolved by
# getenv() is recorded in the OPT_VARS dict.
############################################################################################
import os
import sys
# NOTE(review): sys.modules normally maps module names to module objects; here a plain bool
# is stored, presumably so other modules can read the run mode via sys.modules['TRAINING']
# (confirm against the readers). Code that iterates sys.modules expecting modules may not
# expect this entry.
sys.modules['TRAINING'] = False # False = Inference Testing
###############################################################
OPT_VARS = {}  # var_name -> value resolved by every getenv() call (for logging)


def getenv(var_name, default=None, cast_type=str):
    """Read an environment variable, cast it, and record the resolved value.

    Args:
        var_name: Name of the environment variable to read.
        default: Value returned when the variable is unset, or when its
            text cannot be converted by ``cast_type``.
        cast_type: Callable used to convert the raw string. ``bool`` is
            special-cased (``bool("false")`` would be True): the value is
            True iff, ignoring case and surrounding whitespace, it is
            "true", "1" or "yes"; any other text gives False.

    Returns:
        The converted value, or ``default``. The result is also stored in
        the module-level ``OPT_VARS`` dict under ``var_name``.

    A value that fails to convert still falls back to ``default``, but a
    ``RuntimeWarning`` is emitted so a typo in an override (for example
    ``NUM_EPS_STEPS=3o0``) does not silently run with the default.
    """
    import warnings

    value = os.environ.get(var_name)
    if value is None:
        result = default
    elif cast_type is bool:
        # Parse the text explicitly; strip so values like " true\n" from
        # shell/env files are still recognised.
        result = value.strip().lower() in ("true", "1", "yes")
    else:
        try:
            result = cast_type(value)
        except (ValueError, TypeError) as err:
            warnings.warn(
                f"Environment variable {var_name}={value!r} could not be "
                f"cast to {getattr(cast_type, '__name__', cast_type)} "
                f"({err}); using default {default!r}.",
                RuntimeWarning,
                stacklevel=2,
            )
            result = default
    OPT_VARS[var_name] = result  # Log the result
    return result
###############################################################
# Run / episode settings. Values assigned via getenv() can be overridden by an
# environment variable of the same name.
POLICY = getenv("POLICY", default="RL", cast_type=str)
# TAX_HIERARCHY_TO_CONDENSE = 3 # Remove N layers of the taxonomy hierarchy from the back
NUM_TEST = 800 # Overridden if TAXABIND_TTA is True and performing search ds val
NUM_RUN = 1
SAVE_GIFS = getenv("SAVE_GIFS", default=True, cast_type=bool) # do you want to save GIFs
SAVE_TRAJECTORY = False # do you want to save per-step metrics
SAVE_LENGTH = False # do you want to save per-episode metrics
VIZ_GRAPH_EDGES = False # do you want to visualize the graph edges
# Alternative (commented-out) checkpoints; swap one in for the active MODEL_NAME below.
# MODEL_NAME = "pure_coverage_no_pose_obs_230325_stage1.pth" # checkpoint.pth
# MODEL_NAME = "STAGE2_20k_vlm_search_24x24_290225_NO_TARGET_REWARDS_600steps.pth" # checkpoint.pth
# MODEL_NAME = "vlm_search_24x24_230225_NO_TARGET_REWARDS_600steps.pth" # checkpoint.pth
# MODEL_NAME = "vlm_search_20x20_200125_256steps_CORRECT_REWARDS.pth" # checkpoint.pth
MODEL_NAME = "STAGE1_vlm_search_24x24_040425_no_tgt_rewards_iNAT_DS_16k.pth"
NUM_EPS_STEPS = getenv("NUM_EPS_STEPS", default=384, cast_type=int)
TERMINATE_ON_TGTS_FOUND = True # Whether to terminate episode when all targets found
FORCE_LOGGING_DONE_TGTS_FOUND = True # Whether to force csv logging when all targets found
FIX_START_POSITION = getenv("FIX_START_POSITION", default=True, cast_type=bool) # Whether to fix the starting position of the robots (middle index)
## Source of the initial score mask: CLIP predictions (True) or custom masks from OVERRIDE_MASK_DIR (False)
USE_CLIP_PREDS = getenv("USE_CLIP_PREDS", default=True, cast_type=bool) # If false, use custom masks from OVERRIDE_MASK_DIR
OVERRIDE_MASK_DIR = getenv("OVERRIDE_MASK_DIR", default="/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_in/out_mask_val_in", cast_type=str)
# Alternative (commented-out) mask directories:
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Chordata_Mammalia_Rodentia"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Chordata_Mammalia_Artiodactyla"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Animalia_Arthropoda_Arachnida_Araneae"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/pred/Plantae_Tracheophyta_Magnoliopsida_Caryophyllales"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_in/out_mask_val_in"
# OVERRIDE_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/Baselines_Masks/LLaVA_GroundedSAM/v1_010425/LLaVA_GroundedSAM_out_mask_val_out/out_mask_val_out"
# Used to calculate the info_gain metric
OVERRIDE_GT_MASK_DIR = getenv("OVERRIDE_GT_MASK_DIR", default="", cast_type=str)
# OVERRIDE_GT_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/gt/val_in_4gsnet_score_map"
# OVERRIDE_GT_MASK_DIR = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/score_maps/v1_200325/gt/val_out_4gsnet_score_map"
#######################################################################
# iNAT TTA
#######################################################################
# Query Params
QUERY_TAX = getenv("QUERY_TAX", default="", cast_type=str) # "" = Test all tax
# QUERY_TAX = "Animalia Chordata Mammalia Rodentia" # search_val_in
# QUERY_TAX = "Animalia Chordata Mammalia Artiodactyla" # search_val_in
# QUERY_TAX = "Animalia Arthropoda Arachnida Araneae" # search_val_out
# QUERY_TAX = "Plantae Tracheophyta Magnoliopsida Caryophyllales" # search_val_out
# TTA PARAMS
EXECUTE_TTA = getenv("EXECUTE_TTA", default=True, cast_type=bool) # Whether to execute TTA mask updates
STEPS_PER_TTA = 20 # no. steps before each TTA series
NUM_TTA_STEPS = 1 # no. of TTA steps during each series
INITIAL_MODALITY = getenv("INITIAL_MODALITY", default="image", cast_type=str) # "image", "text", "combined"
MODALITY = getenv("MODALITY", default="image", cast_type=str) # "image", "text", "combined"
QUERY_VARIETY = getenv("QUERY_VARIETY", default=False, cast_type=bool) # bool flag (not a modality string); TODO(review): document what it toggles
RESET_WEIGHTS = True
MIN_LR = 1e-6
MAX_LR = 1e-5 # 1e-5
GAMMA_EXPONENT = 2 # 2
# Paths related to taxabind (TRAIN w/ TARGETS)
TAXABIND_TTA = True # Whether to init TTA classes - FOR NOW: Always True
TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21'
TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_train_jpg_512px' # sat_test_jpg_256px, sat_test_jpg_512px
TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_train.json' # no filter needed
TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_train.json' # no filter needed
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json'
TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = getenv("TAXABIND_SAT_TO_IMG_IDS_JSON_PATH", default="/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/FINAL_COUNTS_COMBINED_top4000_pos_top55000_neg/search_val_in.json", cast_type=str)
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/search_val_in.json"
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/NEW_3-20_Targets/COUNTS_COMBINED_top2000_pos_top22000_neg/search_val_out.json"
TAXABIND_PATCH_SIZE=14
TAXABIND_SAT_CHECKPOINT_PATH="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px_search_ds_filtered/pixel_clip_512px_search_ds_100625_CLIP-L-336_FINAL_SPLIT_LARGE_BUGFIX_CLIP_TRAIN_CORRECT_VAL_IN_TAX_FILTER_TGT_ONLY/satbind-epoch=02-val_loss=2.50_BACKUP.ckpt" # previous: "/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px_search_ds_filtered/pixel_clip_512px_search_ds_070425_CLIP-L-336_FINAL_SPLIT_LARGE/satbind-epoch=02-val_loss=2.48-BACKUP.ckpt"
TAXABIND_GAUSSIAN_BLUR_KERNEL = (5,5)
TAXABIND_SAMPLE_INDEX = 8 # DEBUG (Starting point) 5, 6, 8
# Sound
TAXABIND_SOUND_DATA_PATH = '/mnt/hdd/inat2021_ds/2_OTHERS/sound_test'
TAXABIND_SOUND_CHECKPOINT_PATH = "/home/user/Taxabind/TaxaBind/SoundBind/checkpoints/BUGFIX_CLIP_TRAIN_CORRECT_without_out_domain_taxs_v4_220625/soundbind-epoch=19-val_loss=3.92_BACKUP.ckpt"
# Alternative (commented-out) TaxaBind configurations kept for reference:
# # Paths related to taxabind (TRAIN w/ TARGETS)
# TAXABIND_TTA = True
# TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21'
# TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_train_jpg_512px' # sat_test_jpg_256px, sat_test_jpg_512px
# TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_train.json' # no filter needed
# TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/2_OTHERS/inat21_filtered_pixel_clip_train.json' # no filter needed
# # TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json'
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/OLD/taxon_sat_target_search_100x_per_10-20counts.json"
# TAXABIND_PATCH_SIZE=14
# TAXABIND_SAT_CHECKPOINT_PATH="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt" # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_160225_NO_DATASET_SHUFFLE/satbind-epoch=02-val_loss=2.26-BACKUP.ckpt
# TAXABIND_SAMPLE_INDEX = 99 # (Starting point) 99,141
# # Paths related to taxabind (VAL)
# TAXABIND_TTA = True
# TAXABIND_IMG_DIR = '/mnt/hdd/inat2021_ds/inat21'
# TAXABIND_IMO_DIR = '/mnt/hdd/inat2021_ds/sat_test_jpg_512px' # sat_test_jpg_256px, sat_test_jpg_512px
# TAXABIND_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_val.json' # no filter needed
# TAXABIND_FILTERED_INAT_JSON_PATH = '/mnt/hdd/inat2021_ds/inat21_filtered_pixel_clip_val.json' # no filter needed
# # TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = '/mnt/hdd/inat2021_ds/filtered_mapping_sat_to_img_ids_val.json'
# TAXABIND_SAT_TO_IMG_IDS_JSON_PATH = "/mnt/hdd/inat2021_ds/target_search_ds/taxon_sat_target_search_100x_per_10-20counts.json"
# TAXABIND_PATCH_SIZE=14
# TAXABIND_SAT_CHECKPOINT_PATH="/home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_190225_CLIP-L-336/satbind-epoch=02-val_loss=2.40.ckpt" # /home/user/Taxabind/TaxaBind/SatBind/checkpoints/BACKUP/512px/pixel_clip_512px_160225_NO_DATASET_SHUFFLE/satbind-epoch=02-val_loss=2.26-BACKUP.ckpt
# TAXABIND_SAMPLE_INDEX = 45 # TEMP
#######################################################################
# Pretraining
#######################################################################
# TODO: Get rid of the LISA stuff...
# Active: LISA masks for the trained classes. The commented-out block below is the
# equivalent configuration for the out-of-training ("out_clss") classes.
# If LISA trained clss
GRIDMAP_SET_DIR = "Maps/flair_real_maps/envs_val_trained_clss"
MASK_SET_DIR = "Maps/flair_real_maps_lisa_pred/finetuned_LISA_v3_original_losses/flair_lisa_soft_masks_trained_clss_v3" # original_LISA, finetuned_LISA_v3_original_losses
TARGETS_SET_DIR = "Maps/flair_real_maps/masks_val_trained_clss" # If empty, then targets assumed to be on MASK_SET_DIR
RAW_IMG_PATH_DICT = "Maps/flair_real_maps/flair-ds-paths-filtered-with-scores-val-trained-clss.csv" # flair-ds-paths-filtered-with-scores-train.csv, flair-ds-paths-filtered-with-scores-val-trained-clss.csv, flair-ds-paths-filtered-with-scores-val-out-clss.csv
# # If LISA out clss
# GRIDMAP_SET_DIR = "Maps/flair_real_maps/envs_val_out_clss"
# MASK_SET_DIR = "Maps/flair_real_maps_lisa_pred/finetuned_LISA_v3_original_losses/flair_lisa_soft_masks_out_clss_v3" # original_LISA, finetuned_LISA_v3_original_losses
# TARGETS_SET_DIR = "Maps/flair_real_maps/masks_val_out_clss" # If empty, then targets assumed to be on MASK_SET_DIR
# RAW_IMG_PATH_DICT = "Maps/flair_real_maps/flair-ds-paths-filtered-with-scores-val-out-clss.csv" # flair-ds-paths-filtered-with-scores-train.csv, flair-ds-paths-filtered-with-scores-val-trained-clss.csv, flair-ds-paths-filtered-with-scores-val-out-clss.csv
#######################################################################
NUM_ROBOTS = 1
NUM_COORDS_WIDTH=24 # How many node coords across width?
NUM_COORDS_HEIGHT=24 # How many node coords across height?
HIGH_INFO_REWARD_RATIO = 0.75 # Ratio of rewards for moving to uncertain area (high info vs low info)
SENSOR_RANGE=80 # Only applicable to 'circle' sensor model
SENSOR_MODEL="rectangular" # "rectangular", "circle" (NOTE: no collision check for rectangular)
INPUT_DIM = 4
EMBEDDING_DIM = 128
K_SIZE = 8 # 8
USE_GPU = False # do you want to use GPUs?
# NOTE(review): NUM_GPU presumably only matters when USE_GPU is True — confirm.
NUM_GPU = getenv("NUM_GPU", default=2, cast_type=int) # the number of GPUs
NUM_META_AGENT = getenv("NUM_META_AGENT", default=4, cast_type=int) # the number of processes
FOLDER_NAME = 'inference'
# Output locations, all rooted at FOLDER_NAME.
model_path = f'{FOLDER_NAME}/model'
gifs_path = f'{FOLDER_NAME}/test_results/gifs'
trajectory_path = f'{FOLDER_NAME}/test_results/trajectory'
length_path = f'{FOLDER_NAME}/test_results/length'
log_path = f'{FOLDER_NAME}/test_results/log'
CSV_EXPT_NAME = getenv("CSV_EXPT_NAME", default="data", cast_type=str)
# trajectory_path = f'results/trajectory'
# length_path = f'results/length'
# ANSI escape codes for colored terminal output (for printing)
RED='\033[1;31m'
GREEN='\033[1;32m'
YELLOW='\033[1;93m'
NC_BOLD='\033[1m' # Bold, No Color
NC='\033[0m' # No Color