diff --git a/README.md b/README.md index e51a12b..a6e1ca1 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,21 @@ conda activate jat pip install -e .[dev] ``` +## REGENT fork of sample-factory: Installation +Following [this install ink](https://www.samplefactory.dev/01-get-started/installation/) but for the fork: +```shell +git clone https://github.com/kaustubhsridhar/sample-factory.git +cd sample-factory +pip install -e .[dev,mujoco,atari,envpool,vizdoom] +``` + +# Regent fork of sample-factory: Train Unseen Env Policies and Generate Datasets +Train policies using envpool's atari: +```shell +bash scripts_sample-factory/train_unseen_atari.sh +``` +Note that the training command inside the above script was obtained from the config files of Ed Beeching's Atari 57 models on Huggingface. An example is [here](https://huggingface.co/edbeeching/atari_2B_atari_mspacman_1111/blob/main/cfg.json#L124). See my discussion [here](https://huggingface.co/edbeeching/atari_2B_atari_mspacman_1111/discussions/2). + ## PREV Installation To get started with JAT, follow these steps: @@ -155,12 +170,21 @@ python -u scripts_jat_regent/eval_RandP.py --task ${TASK} &> outputs/RandP/${TAS ``` ### REGENT Analyze data +Necessary: ```shell -python -u examples_regent/compare_datasets.py &> examples_regent/compare_datasets.txt & - python -u examples_regent/analyze_rows_tokenized.py &> examples_regent/analyze_rows_tokenized.txt & +``` +Already ran and output dict in code: +```shell python -u examples_regent/get_dim_all_vector_tasks.py &> examples_regent/get_dim_all_vector_tasks.txt & + +python -u examples_regent/count_rows_to_consider.py &> examples_regent/count_rows_to_consider.txt & +``` + +Optional: +```shell +python -u examples_regent/compare_datasets.py &> examples_regent/compare_datasets.txt & ``` ## PREV Dataset diff --git a/jat_regent/RandP.py b/jat_regent/RandP.py deleted file mode 100644 index b2bd8bf..0000000 --- a/jat_regent/RandP.py +++ /dev/null @@ -1,38 +0,0 @@ -import warnings -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn.functional as F -from gymnasium import spaces -from torch import BoolTensor, FloatTensor, LongTensor, Tensor, nn -from transformers import GPTNeoModel, GPTNeoPreTrainedModel -from transformers.modeling_outputs import ModelOutput -from transformers.models.vit.modeling_vit import ViTPatchEmbeddings - -from jat.configuration_jat import JatConfig -from jat.processing_jat import JatProcessor - - -class RandP(): - def __init__(self, dataset) -> None: - self.steps = 0 - # create an index for retrieval in vector obs envs (OR) collect all images in Atari - - def reset_rl(self): - self.steps = 0 - - def get_next_action( - self, - processor: JatProcessor, - continuous_observation: Optional[List[float]] = None, - discrete_observation: Optional[List[int]] = None, - text_observation: Optional[str] = None, - image_observation: Optional[np.ndarray] = None, - action_space: Union[spaces.Box, spaces.Discrete] = None, - reward: Optional[float] = None, - deterministic: bool = False, - context_window: Optional[int] = None, - ): - pass \ No newline at end of file diff --git a/jat_regent/modelling_jat_regent.py b/jat_regent/modelling_jat_regent.py deleted file mode 100644 index e69de29..0000000 diff --git a/jat_regent/utils.py b/jat_regent/utils.py index 56bfb44..36f6cca 100644 --- a/jat_regent/utils.py +++ b/jat_regent/utils.py @@ -8,23 +8,35 @@ from tqdm import tqdm from autofaiss import build_index +UNSEEN_TASK_NAMES = { # Total -- atari: 57, metaworld: 50, babyai: 39, mujoco: 11 + +} + def myprint(str): - # check if first character of string is a newline character - if str[0] == '\n': - str_without_newline = str[1:] + # check if first characters of string are newline character + num_newlines = 0 + while str[num_newlines] == '\n': print() - else: - str_without_newline = str + num_newlines += 1 + str_without_newline = str[num_newlines:] print(f'{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}: {str_without_newline}') def is_png_img(item): return isinstance(item, PngImagePlugin.PngImageFile) +def get_last_row_for_1M_states(task): + last_row_idx = {'atari-alien': 14134, 'atari-amidar': 14319, 'atari-assault': 14427, 'atari-asterix': 14456, 'atari-asteroids': 14348, 'atari-atlantis': 14325, 'atari-bankheist': 14167, 'atari-battlezone': 13981, 'atari-beamrider': 13442, 'atari-berzerk': 13534, 'atari-bowling': 14110, 'atari-boxing': 14542, 'atari-breakout': 13474, 'atari-centipede': 14196, 'atari-choppercommand': 13397, 'atari-crazyclimber': 14026, 'atari-defender': 13504, 'atari-demonattack': 13499, 'atari-doubledunk': 14292, 'atari-enduro': 13260, 'atari-fishingderby': 14073, 'atari-freeway': 14016, 'atari-frostbite': 14075, 'atari-gopher': 13143, 'atari-gravitar': 14405, 'atari-hero': 14044, 'atari-icehockey': 14017, 'atari-jamesbond': 12678, 'atari-kangaroo': 14248, 'atari-krull': 14204, 'atari-kungfumaster': 14030, 'atari-montezumarevenge': 14219, 'atari-mspacman': 14120, 'atari-namethisgame': 13575, 'atari-phoenix': 13539, 'atari-pitfall': 14287, 'atari-pong': 14151, 'atari-privateeye': 14105, 'atari-qbert': 14026, 'atari-riverraid': 14275, 'atari-roadrunner': 14127, 'atari-robotank': 14079, 'atari-seaquest': 14097, 'atari-skiing': 14708, 'atari-solaris': 14199, 'atari-spaceinvaders': 12652, 'atari-stargunner': 13822, 'atari-surround': 13840, 'atari-tennis': 14062, 'atari-timepilot': 13896, 'atari-tutankham': 13121, 'atari-upndown': 13504, 'atari-venture': 14260, 'atari-videopinball': 14272, 'atari-wizardofwor': 13920, 'atari-yarsrevenge': 13981, 'atari-zaxxon': 13833, 'babyai-action-obj-door': 95000, 'babyai-blocked-unlock-pickup': 29279, 'babyai-boss-level-no-unlock': 12087, 'babyai-boss-level': 12101, 'babyai-find-obj-s5': 32974, 'babyai-go-to-door': 95000, 'babyai-go-to-imp-unlock': 9286, 'babyai-go-to-local': 95000, 'babyai-go-to-obj-door': 95000, 'babyai-go-to-obj': 95000, 'babyai-go-to-red-ball-grey': 95000, 'babyai-go-to-red-ball-no-dists': 95000, 'babyai-go-to-red-ball': 95000, 'babyai-go-to-red-blue-ball': 95000, 'babyai-go-to-seq': 13744, 'babyai-go-to': 18974, 'babyai-key-corridor': 9014, 'babyai-mini-boss-level': 38119, 'babyai-move-two-across-s8n9': 24505, 'babyai-one-room-s8': 95000, 'babyai-open-door': 95000, 'babyai-open-doors-order-n4': 95000, 'babyai-open-red-door': 95000, 'babyai-open-two-doors': 73291, 'babyai-open': 32559, 'babyai-pickup-above': 34084, 'babyai-pickup-dist': 89640, 'babyai-pickup-loc': 95000, 'babyai-pickup': 18670, 'babyai-put-next-local': 83187, 'babyai-put-next': 56986, 'babyai-synth-loc': 21605, 'babyai-synth-seq': 13049, 'babyai-synth': 19409, 'babyai-unblock-pickup': 17881, 'babyai-unlock-local': 71186, 'babyai-unlock-pickup': 50883, 'babyai-unlock-to-unlock': 23062, 'babyai-unlock': 11734, 'metaworld-assembly': 10000, 'metaworld-basketball': 10000, 'metaworld-bin-picking': 10000, 'metaworld-box-close': 10000, 'metaworld-button-press-topdown-wall': 10000, 'metaworld-button-press-topdown': 10000, 'metaworld-button-press-wall': 10000, 'metaworld-button-press': 10000, 'metaworld-coffee-button': 10000, 'metaworld-coffee-pull': 10000, 'metaworld-coffee-push': 10000, 'metaworld-dial-turn': 10000, 'metaworld-disassemble': 10000, 'metaworld-door-close': 10000, 'metaworld-door-lock': 10000, 'metaworld-door-open': 10000, 'metaworld-door-unlock': 10000, 'metaworld-drawer-close': 10000, 'metaworld-drawer-open': 10000, 'metaworld-faucet-close': 10000, 'metaworld-faucet-open': 10000, 'metaworld-hammer': 10000, 'metaworld-hand-insert': 10000, 'metaworld-handle-press-side': 10000, 'metaworld-handle-press': 10000, 'metaworld-handle-pull-side': 10000, 'metaworld-handle-pull': 10000, 'metaworld-lever-pull': 10000, 'metaworld-peg-insert-side': 10000, 'metaworld-peg-unplug-side': 10000, 'metaworld-pick-out-of-hole': 10000, 'metaworld-pick-place-wall': 10000, 'metaworld-pick-place': 10000, 'metaworld-plate-slide-back-side': 10000, 'metaworld-plate-slide-back': 10000, 'metaworld-plate-slide-side': 10000, 'metaworld-plate-slide': 10000, 'metaworld-push-back': 10000, 'metaworld-push-wall': 10000, 'metaworld-push': 10000, 'metaworld-reach-wall': 10000, 'metaworld-reach': 10000, 'metaworld-shelf-place': 10000, 'metaworld-soccer': 10000, 'metaworld-stick-pull': 10000, 'metaworld-stick-push': 10000, 'metaworld-sweep-into': 10000, 'metaworld-sweep': 10000, 'metaworld-window-close': 10000, 'metaworld-window-open': 10000, 'mujoco-ant': 4023, 'mujoco-doublependulum': 4002, 'mujoco-halfcheetah': 4000, 'mujoco-hopper': 4931, 'mujoco-humanoid': 4119, 'mujoco-pendulum': 4959, 'mujoco-pusher': 9000, 'mujoco-reacher': 9000, 'mujoco-standup': 4000, 'mujoco-swimmer': 4000, 'mujoco-walker': 4101} + return last_row_idx[task] + +def get_last_row_for_100k_states(task): + last_row_idx = {'atari-alien': 3135, 'atari-amidar': 3142, 'atari-assault': 3132, 'atari-asterix': 3181, 'atari-asteroids': 3127, 'atari-atlantis': 3128, 'atari-bankheist': 3156, 'atari-battlezone': 3136, 'atari-beamrider': 3131, 'atari-berzerk': 3127, 'atari-bowling': 3148, 'atari-boxing': 3227, 'atari-breakout': 3128, 'atari-centipede': 3176, 'atari-choppercommand': 3144, 'atari-crazyclimber': 3134, 'atari-defender': 3127, 'atari-demonattack': 3127, 'atari-doubledunk': 3175, 'atari-enduro': 3126, 'atari-fishingderby': 3155, 'atari-freeway': 3131, 'atari-frostbite': 3146, 'atari-gopher': 3128, 'atari-gravitar': 3202, 'atari-hero': 3144, 'atari-icehockey': 3138, 'atari-jamesbond': 3131, 'atari-kangaroo': 3160, 'atari-krull': 3162, 'atari-kungfumaster': 3143, 'atari-montezumarevenge': 3168, 'atari-mspacman': 3143, 'atari-namethisgame': 3131, 'atari-phoenix': 3127, 'atari-pitfall': 3131, 'atari-pong': 3160, 'atari-privateeye': 3158, 'atari-qbert': 3136, 'atari-riverraid': 3157, 'atari-roadrunner': 3150, 'atari-robotank': 3133, 'atari-seaquest': 3138, 'atari-skiing': 3271, 'atari-solaris': 3129, 'atari-spaceinvaders': 3128, 'atari-stargunner': 3129, 'atari-surround': 3143, 'atari-tennis': 3129, 'atari-timepilot': 3132, 'atari-tutankham': 3127, 'atari-upndown': 3127, 'atari-venture': 3148, 'atari-videopinball': 3130, 'atari-wizardofwor': 3138, 'atari-yarsrevenge': 3129, 'atari-zaxxon': 3133, 'babyai-action-obj-door': 15923, 'babyai-blocked-unlock-pickup': 2919, 'babyai-boss-level-no-unlock': 1217, 'babyai-boss-level': 1159, 'babyai-find-obj-s5': 3345, 'babyai-go-to-door': 18875, 'babyai-go-to-imp-unlock': 923, 'babyai-go-to-local': 18724, 'babyai-go-to-obj-door': 16472, 'babyai-go-to-obj': 20197, 'babyai-go-to-red-ball-grey': 16953, 'babyai-go-to-red-ball-no-dists': 20165, 'babyai-go-to-red-ball': 18730, 'babyai-go-to-red-blue-ball': 16934, 'babyai-go-to-seq': 1439, 'babyai-go-to': 1964, 'babyai-key-corridor': 900, 'babyai-mini-boss-level': 3789, 'babyai-move-two-across-s8n9': 2462, 'babyai-one-room-s8': 16994, 'babyai-open-door': 13565, 'babyai-open-doors-order-n4': 9706, 'babyai-open-red-door': 21185, 'babyai-open-two-doors': 7348, 'babyai-open': 3331, 'babyai-pickup-above': 3392, 'babyai-pickup-dist': 19693, 'babyai-pickup-loc': 16405, 'babyai-pickup': 1806, 'babyai-put-next-local': 8303, 'babyai-put-next': 5703, 'babyai-synth-loc': 2183, 'babyai-synth-seq': 1316, 'babyai-synth': 1964, 'babyai-unblock-pickup': 1886, 'babyai-unlock-local': 7118, 'babyai-unlock-pickup': 5107, 'babyai-unlock-to-unlock': 2309, 'babyai-unlock': 1177, 'metaworld-assembly': 1000, 'metaworld-basketball': 1000, 'metaworld-bin-picking': 1000, 'metaworld-box-close': 1000, 'metaworld-button-press-topdown-wall': 1000, 'metaworld-button-press-topdown': 1000, 'metaworld-button-press-wall': 1000, 'metaworld-button-press': 1000, 'metaworld-coffee-button': 1000, 'metaworld-coffee-pull': 1000, 'metaworld-coffee-push': 1000, 'metaworld-dial-turn': 1000, 'metaworld-disassemble': 1000, 'metaworld-door-close': 1000, 'metaworld-door-lock': 1000, 'metaworld-door-open': 1000, 'metaworld-door-unlock': 1000, 'metaworld-drawer-close': 1000, 'metaworld-drawer-open': 1000, 'metaworld-faucet-close': 1000, 'metaworld-faucet-open': 1000, 'metaworld-hammer': 1000, 'metaworld-hand-insert': 1000, 'metaworld-handle-press-side': 1000, 'metaworld-handle-press': 1000, 'metaworld-handle-pull-side': 1000, 'metaworld-handle-pull': 1000, 'metaworld-lever-pull': 1000, 'metaworld-peg-insert-side': 1000, 'metaworld-peg-unplug-side': 1000, 'metaworld-pick-out-of-hole': 1000, 'metaworld-pick-place-wall': 1000, 'metaworld-pick-place': 1000, 'metaworld-plate-slide-back-side': 1000, 'metaworld-plate-slide-back': 1000, 'metaworld-plate-slide-side': 1000, 'metaworld-plate-slide': 1000, 'metaworld-push-back': 1000, 'metaworld-push-wall': 1000, 'metaworld-push': 1000, 'metaworld-reach-wall': 1000, 'metaworld-reach': 1000, 'metaworld-shelf-place': 1000, 'metaworld-soccer': 1000, 'metaworld-stick-pull': 1000, 'metaworld-stick-push': 1000, 'metaworld-sweep-into': 1000, 'metaworld-sweep': 1000, 'metaworld-window-close': 1000, 'metaworld-window-open': 1000, 'mujoco-ant': 401, 'mujoco-doublependulum': 401, 'mujoco-halfcheetah': 400, 'mujoco-hopper': 491, 'mujoco-humanoid': 415, 'mujoco-pendulum': 495, 'mujoco-pusher': 1000, 'mujoco-reacher': 2000, 'mujoco-standup': 400, 'mujoco-swimmer': 400, 'mujoco-walker': 407} + return last_row_idx[task] + def get_obs_dim(task): assert task.startswith("babyai") or task.startswith("metaworld") or task.startswith("mujoco") all_obs_dims={'babyai-action-obj-door': 212, 'babyai-blocked-unlock-pickup': 212, 'babyai-boss-level-no-unlock': 212, 'babyai-boss-level': 212, 'babyai-find-obj-s5': 212, 'babyai-go-to-door': 212, 'babyai-go-to-imp-unlock': 212, 'babyai-go-to-local': 212, 'babyai-go-to-obj-door': 212, 'babyai-go-to-obj': 212, 'babyai-go-to-red-ball-grey': 212, 'babyai-go-to-red-ball-no-dists': 212, 'babyai-go-to-red-ball': 212, 'babyai-go-to-red-blue-ball': 212, 'babyai-go-to-seq': 212, 'babyai-go-to': 212, 'babyai-key-corridor': 212, 'babyai-mini-boss-level': 212, 'babyai-move-two-across-s8n9': 212, 'babyai-one-room-s8': 212, 'babyai-open-door': 212, 'babyai-open-doors-order-n4': 212, 'babyai-open-red-door': 212, 'babyai-open-two-doors': 212, 'babyai-open': 212, 'babyai-pickup-above': 212, 'babyai-pickup-dist': 212, 'babyai-pickup-loc': 212, 'babyai-pickup': 212, 'babyai-put-next-local': 212, 'babyai-put-next': 212, 'babyai-synth-loc': 212, 'babyai-synth-seq': 212, 'babyai-synth': 212, 'babyai-unblock-pickup': 212, 'babyai-unlock-local': 212, 'babyai-unlock-pickup': 212, 'babyai-unlock-to-unlock': 212, 'babyai-unlock': 212, 'metaworld-assembly': 39, 'metaworld-basketball': 39, 'metaworld-bin-picking': 39, 'metaworld-box-close': 39, 'metaworld-button-press-topdown-wall': 39, 'metaworld-button-press-topdown': 39, 'metaworld-button-press-wall': 39, 'metaworld-button-press': 39, 'metaworld-coffee-button': 39, 'metaworld-coffee-pull': 39, 'metaworld-coffee-push': 39, 'metaworld-dial-turn': 39, 'metaworld-disassemble': 39, 'metaworld-door-close': 39, 'metaworld-door-lock': 39, 'metaworld-door-open': 39, 'metaworld-door-unlock': 39, 'metaworld-drawer-close': 39, 'metaworld-drawer-open': 39, 'metaworld-faucet-close': 39, 'metaworld-faucet-open': 39, 'metaworld-hammer': 39, 'metaworld-hand-insert': 39, 'metaworld-handle-press-side': 39, 'metaworld-handle-press': 39, 'metaworld-handle-pull-side': 39, 'metaworld-handle-pull': 39, 'metaworld-lever-pull': 39, 'metaworld-peg-insert-side': 39, 'metaworld-peg-unplug-side': 39, 'metaworld-pick-out-of-hole': 39, 'metaworld-pick-place-wall': 39, 'metaworld-pick-place': 39, 'metaworld-plate-slide-back-side': 39, 'metaworld-plate-slide-back': 39, 'metaworld-plate-slide-side': 39, 'metaworld-plate-slide': 39, 'metaworld-push-back': 39, 'metaworld-push-wall': 39, 'metaworld-push': 39, 'metaworld-reach-wall': 39, 'metaworld-reach': 39, 'metaworld-shelf-place': 39, 'metaworld-soccer': 39, 'metaworld-stick-pull': 39, 'metaworld-stick-push': 39, 'metaworld-sweep-into': 39, 'metaworld-sweep': 39, 'metaworld-window-close': 39, 'metaworld-window-open': 39, 'mujoco-ant': 27, 'mujoco-doublependulum': 11, 'mujoco-halfcheetah': 17, 'mujoco-hopper': 11, 'mujoco-humanoid': 376, 'mujoco-pendulum': 4, 'mujoco-pusher': 23, 'mujoco-reacher': 11, 'mujoco-standup': 376, 'mujoco-swimmer': 8, 'mujoco-walker': 17} - return all_obs_dims[task] + return (all_obs_dims[task],) def get_act_dim(task): assert task.startswith("babyai") or task.startswith("metaworld") or task.startswith("mujoco") @@ -36,141 +48,188 @@ def get_act_dim(task): elif task.startswith("mujoco"): all_act_dims={'mujoco-ant': 8, 'mujoco-doublependulum': 1, 'mujoco-halfcheetah': 6, 'mujoco-hopper': 3, 'mujoco-humanoid': 17, 'mujoco-pendulum': 1, 'mujoco-pusher': 7, 'mujoco-reacher': 2, 'mujoco-standup': 17, 'mujoco-swimmer': 2, 'mujoco-walker': 6} return all_act_dims[task] - -def process_row_atari(attn_mask, row_of_obs, task): - """ - Example for selection with bools: - >>> a = np.array([0,1,2,3,4,5]) - >>> b = np.array([1,0,0,0,0,1]).astype(bool) - >>> a[b] - array([0, 5]) - """ - attn_mask = np.array(attn_mask).astype(bool) - row_of_obs = torch.stack([to_tensor(np.array(img)) for img in row_of_obs]) - row_of_obs = row_of_obs[attn_mask] +def get_task_info(task): + rew_key = 'rewards' + attn_key = 'attention_mask' + if task.startswith("atari"): + obs_key = 'image_observations' + act_key = 'discrete_actions' + B = 32 # half of 54 + obs_dim = (3, 4*84, 84) + elif task.startswith("babyai"): + obs_key = 'discrete_observations' # also has 'text_observations' only for raw dataset not for tokenized dataset (as it is combined into discrete_observation in tokenized dataset) + act_key = 'discrete_actions' + B = 256 # half of 512 + obs_dim = get_obs_dim(task) + elif task.startswith("metaworld") or task.startswith("mujoco"): + obs_key = 'continuous_observations' + act_key = 'continuous_actions' + B = 256 + obs_dim = get_obs_dim(task) + + return rew_key, attn_key, obs_key, act_key, B, obs_dim + +def process_row_of_obs_atari_full_without_mask(row_of_obs): + + if not isinstance(row_of_obs, torch.Tensor): + row_of_obs = torch.stack([to_tensor(np.array(img)) for img in row_of_obs]) row_of_obs = row_of_obs * 0.5 + 0.5 # denormalize from [-1, 1] to [0, 1] - assert row_of_obs.shape == (sum(attn_mask), 84, 4, 84) + assert row_of_obs.shape == (len(row_of_obs), 84, 4, 84) row_of_obs = row_of_obs.permute(0, 2, 1, 3) # (*, 4, 84, 84) - row_of_obs = row_of_obs.reshape(sum(attn_mask), 4*84, 84) # put side-by-side + row_of_obs = row_of_obs.reshape(len(row_of_obs), 4*84, 84) # put side-by-side row_of_obs = row_of_obs.unsqueeze(1).repeat(1, 3, 1, 1) # repeat for 3 channels - assert row_of_obs.shape == (sum(attn_mask), 3, 4*84, 84) # sum(attn_mask) is the batch size dimension - - return attn_mask, row_of_obs + assert row_of_obs.shape == (len(row_of_obs), 3, 4*84, 84) # sum(attn_mask) is the batch size dimension + + return row_of_obs -def process_row_vector(attn_mask, row_of_obs, task, return_numpy=False): - attn_mask = np.array(attn_mask).astype(bool) +def collect_all_atari_data(dataset, all_row_idxs=None): + if all_row_idxs is None: + all_row_idxs = list(range(len(dataset['train']))) - row_of_obs = np.array(row_of_obs) - if not return_numpy: - row_of_obs = torch.tensor(row_of_obs) - row_of_obs = row_of_obs[attn_mask] - assert row_of_obs.shape == (sum(attn_mask), get_obs_dim(task)) - - return attn_mask, row_of_obs - -def retrieve_atari(row_of_obs, # query: (row_B, 3, 4*84, 84) - dataset, # to retrieve from - all_rows_to_consider, # rows to consider - num_to_retrieve, # top-k + all_rows_of_obs = [] + all_attn_masks = [] + for row_idx in tqdm(all_row_idxs): + datarow = dataset['train'][row_idx] + row_of_obs = process_row_of_obs_atari_full_without_mask(datarow['image_observations']) + attn_mask = np.array(datarow['attention_mask']).astype(bool) + all_rows_of_obs.append(row_of_obs) # appending tensor + all_attn_masks.append(attn_mask) # appending np array + all_rows_of_obs = torch.stack(all_rows_of_obs, dim=0) # stacking tensors + all_attn_masks = np.stack(all_attn_masks, axis=0) # concatenating np arrays + assert (all_rows_of_obs.shape == (len(all_row_idxs), 32, 3, 4*84, 84) and + all_attn_masks.shape == (len(all_row_idxs), 32)) + return all_attn_masks, all_rows_of_obs + +def collect_all_data(dataset, task, obs_key): + last_row_idx = get_last_row_for_100k_states(task) + all_row_idxs = list(range(last_row_idx)) + if task.startswith("atari"): + myprint("Collecting all Atari images and Atari attention masks...") + all_attn_masks_OG, all_rows_of_obs_OG = collect_all_atari_data(dataset, all_row_idxs) + else: + datarows = dataset['train'][all_row_idxs] + all_rows_of_obs_OG = np.array(datarows[obs_key]) + all_attn_masks_OG = np.array(datarows['attention_mask']).astype(bool) + return all_rows_of_obs_OG, all_attn_masks_OG, all_row_idxs + +def collect_subset(all_rows_of_obs_OG, + all_attn_masks_OG, + all_rows_to_consider, + kwargs + ): + """ + Function to collect subset of data given all_rows_to_consider, reshape it, create all_indices and return. + Used in both retrieve_atari() and retrieve_vector() --> build_index_vector(). + """ + myprint(f'\n\n\n' + ('-'*100) + f'Collecting subset...') + # read kwargs + B, task, obs_dim = kwargs['B'], kwargs['task'], kwargs['obs_dim'] + + # take subset based on all_rows_to_consider + myprint(f'Taking subset of data based on all_rows_to_consider...') + all_processed_rows_of_obs = all_rows_of_obs_OG[all_rows_to_consider] + all_attn_masks = all_attn_masks_OG[all_rows_to_consider] + assert (all_processed_rows_of_obs.shape == (len(all_rows_to_consider), B, *obs_dim) and + all_attn_masks.shape == (len(all_rows_to_consider), B)) + + # reshape + myprint(f'Reshaping data...') + all_attn_masks = all_attn_masks.reshape(-1) + all_processed_rows_of_obs = all_processed_rows_of_obs.reshape(-1, *obs_dim) + all_processed_rows_of_obs = all_processed_rows_of_obs[all_attn_masks] + assert (all_attn_masks.shape == (len(all_rows_to_consider) * B,) and + all_processed_rows_of_obs.shape == (np.sum(all_attn_masks), *obs_dim)) + + # collect indices of data + myprint(f'Collecting indices of data...') + all_indices = np.array([[row_idx, i] for row_idx in all_rows_to_consider for i in range(B)]) + all_indices = all_indices[all_attn_masks] # this is fine because all attn masks have 0s that only come after 1s + assert all_indices.shape == (np.sum(all_attn_masks), 2) + + myprint(f'{all_indices.shape=}, {all_processed_rows_of_obs.shape=}') + myprint(('-'*100) + '\n\n\n') + return all_indices, all_processed_rows_of_obs + +def retrieve_atari(row_of_obs, # query: (xbdim, 3, 4*84, 84) / (xdim *obs_dim) + all_processed_rows_of_obs, + all_indices, + num_to_retrieve, kwargs - ): + ): + """ + Retrieval for Atari with images, ssim distance, and on GPU. + """ assert isinstance(row_of_obs, torch.Tensor) # read kwargs # Note: B = len of row - B, attn_key, obs_key, device, task, batch_size_retrieval = kwargs['B'], kwargs['attn_key'], kwargs['obs_key'], kwargs['device'], kwargs['task'], kwargs['batch_size_retrieval'] + B, device, batch_size_retrieval = kwargs['B'], kwargs['device'], kwargs['batch_size_retrieval'] # batch size of row_of_obs which can be <= B since we process before calling this function - row_B = row_of_obs.shape[0] - + xbdim = row_of_obs.shape[0] + + # collect subset of data that we can retrieve from + ydim = all_processed_rows_of_obs.shape[0] + # first argument for ssim - repeated_row_og = row_of_obs.repeat_interleave(B, dim=0).to(device) - assert repeated_row_og.shape == (row_B*B, 3, 4*84, 84) + xbatch = row_of_obs.repeat_interleave(batch_size_retrieval, dim=0).to(device) + assert xbatch.shape == (xbdim * batch_size_retrieval, 3, 4*84, 84) - # iterate over all other rows + # iterate over data that we can retrieve from in batches all_ssim = [] - all_indices = [] - total = 0 - for other_row_idx in tqdm(all_rows_to_consider): - other_attn_mask, other_row_of_obs = process_row_atari(dataset['train'][other_row_idx][attn_key], dataset['train'][other_row_idx][obs_key]) - - # batch size of other_row_of_obs - other_row_B = other_row_of_obs.shape[0] - total += other_row_B - - # first argument for ssim: RECHECK - if other_row_B < B: # when other row has less observations than expected - repeated_row = row_of_obs.repeat_interleave(other_row_B, dim=0).to(device) - elif other_row_B == B: # otherwise just use the one created before the for loop - repeated_row = repeated_row_og - assert repeated_row.shape == (row_B*other_row_B, 3, 4*84, 84) - + for j in range(0, ydim, batch_size_retrieval): # second argument for ssim - repeated_other_row = other_row_of_obs.repeat(row_B, 1, 1, 1).to(device) - assert repeated_other_row.shape == (row_B*other_row_B, 3, 4*84, 84) + ybatch = all_processed_rows_of_obs[j:j+batch_size_retrieval] + ybdim = ybatch.shape[0] + ybatch = ybatch.repeat(xbdim, 1, 1, 1).to(device) + assert ybatch.shape == (ybdim * xbdim, 3, 4*84, 84) + + if ybdim < batch_size_retrieval: # for last batch + xbatch = row_of_obs.repeat_interleave(ybdim, dim=0).to(device) + assert xbatch.shape == (xbdim * ybdim, 3, 4*84, 84) # compare via ssim and updated all_ssim - ssim_score = ssim(repeated_row, repeated_other_row, data_range=1.0, size_average=False) - ssim_score = ssim_score.reshape(row_B, other_row_B) + ssim_score = ssim(xbatch, ybatch, data_range=1.0, size_average=False) + ssim_score = ssim_score.reshape(xbdim, ybdim) all_ssim.append(ssim_score) - # update all_indices - all_indices.extend([[other_row_idx, i] for i in range(other_row_B)]) - # concat all_ssim = torch.cat(all_ssim, dim=1) - assert all_ssim.shape == (row_B, total) + assert all_ssim.shape == (xbdim, ydim) - all_indices = np.array(all_indices) - assert all_indices.shape == (total, 2) + assert all_indices.shape == (ydim, 2) # get top-k indices topk_values, topk_indices = torch.topk(all_ssim, num_to_retrieve, dim=1, largest=True) topk_indices = topk_indices.cpu().numpy() - assert topk_indices.shape == (row_B, num_to_retrieve) + assert topk_indices.shape == (xbdim, num_to_retrieve) # convert topk indices to indices in the dataset - retrieved_indices = np.array(all_indices[topk_indices]) - assert retrieved_indices.shape == (row_B, num_to_retrieve, 2) - - # pad the above to expected B - if row_B < B: - retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-row_B, num_to_retrieve, 2), dtype=int)], axis=0) - assert retrieved_indices.shape == (B, num_to_retrieve, 2) + retrieved_indices = all_indices[topk_indices] + assert retrieved_indices.shape == (xbdim, num_to_retrieve, 2) return retrieved_indices -def build_index_vector(all_rows_of_obs_og, - all_attn_masks_og, +def build_index_vector(all_rows_of_obs_OG, + all_attn_masks_OG, all_rows_to_consider, kwargs - ): + ): + """ + Builds FAISS index for vector observation environments. + """ # read kwargs # Note: B = len of row - B, attn_key, obs_key, device, task, batch_size_retrieval, nb_cores_autofaiss = kwargs['B'], kwargs['attn_key'], kwargs['obs_key'], kwargs['device'], kwargs['task'], kwargs['batch_size_retrieval'], kwargs['nb_cores_autofaiss'] - obs_dim = get_obs_dim(task) + nb_cores_autofaiss = kwargs['nb_cores_autofaiss'] - # take subset based on all_rows_to_consider - myprint(f'Taking subset') - all_rows_of_obs = all_rows_of_obs_og[all_rows_to_consider] - all_attn_masks = all_attn_masks_og[all_rows_to_consider] - assert (all_rows_of_obs.shape == (len(all_rows_to_consider), B, obs_dim) and - all_attn_masks.shape == (len(all_rows_to_consider), B)) - - # reshape - all_attn_masks = all_attn_masks.reshape(-1) - all_rows_of_obs = all_rows_of_obs.reshape(-1, obs_dim) - all_rows_of_obs = all_rows_of_obs[all_attn_masks] - assert all_rows_of_obs.shape == (np.sum(all_attn_masks), obs_dim) + # take subset based on all_rows_to_consider, reshape, and save indices of data + all_indices, all_processed_rows_of_obs = collect_subset(all_rows_of_obs_OG, all_attn_masks_OG, all_rows_to_consider, kwargs) - # save indices of data to retrieve from - myprint(f'Saving indices of data to retrieve from') - all_indices = np.array([[row_idx, i] for row_idx in all_rows_to_consider for i in range(B)]) - all_indices = all_indices[all_attn_masks] # this is fine because all attn masks have 0s that only come after 1s - assert all_indices.shape == (np.sum(all_attn_masks), 2) + # make sure input to build_index is float, otherwise you will get reading temp file error + all_processed_rows_of_obs = all_processed_rows_of_obs.astype(float) # build index - myprint(f'Building index...') - knn_index, knn_index_infos = build_index(embeddings=all_rows_of_obs, # Note: embeddings have to be float to avoid errors in autofaiss / embedding_reader! + myprint(('-'*100) + 'Building index...') + knn_index, knn_index_infos = build_index(embeddings=all_processed_rows_of_obs, # Note: embeddings have to be float to avoid errors in autofaiss / embedding_reader! save_on_disk=False, min_nearest_neighbors_to_retrieve=20, # default: 20 max_index_query_time_ms=10, # default: 10 @@ -179,34 +238,32 @@ def build_index_vector(all_rows_of_obs_og, metric_type='l2', nb_cores=nb_cores_autofaiss, # default: None # "The number of cores to use, by default will use all cores" as seen in https://criteo.github.io/autofaiss/getting_started/quantization.html#the-build-index-command ) + myprint(('-'*100) + '\n\n\n') - return knn_index, all_indices + return all_indices, knn_index -def retrieve_vector(row_of_obs, # query: (row_B, dim) - dataset, # to retrieve from - all_rows_to_consider, # rows to consider - num_to_retrieve, # top-k +def retrieve_vector(row_of_obs, # query: (xbdim, *obs_dim) + knn_index, + all_indices, + num_to_retrieve, kwargs - ): + ): + """ + Retrieval for vector observation environments. + """ assert isinstance(row_of_obs, np.ndarray) # read few kwargs B = kwargs['B'] # batch size of row_of_obs which can be <= B since we process before calling this function - row_B = row_of_obs.shape[0] + xbdim = row_of_obs.shape[0] - # read dataset_tuple - all_rows_of_obs, all_attn_masks = dataset - - # create index and all_indices - knn_index, all_indices = build_index_vector(all_rows_of_obs, all_attn_masks, all_rows_to_consider, kwargs) - # retrieve myprint(f'Retrieving...') topk_indices, _ = knn_index.search(row_of_obs, 10 * num_to_retrieve) topk_indices = topk_indices.astype(int) - assert topk_indices.shape == (row_B, 10 * num_to_retrieve) + assert topk_indices.shape == (xbdim, 10 * num_to_retrieve) # remove -1s and crop to num_to_retrieve try: @@ -219,16 +276,10 @@ def retrieve_vector(row_of_obs, # query: (row_B, dim) print(f'-------------------------------------------------------------------------------------------------------------------------------------------') print(f'Leaving some -1s in topk_indices and continuing') topk_indices = np.array([indices[:num_to_retrieve] for indices in topk_indices]) - assert topk_indices.shape == (row_B, num_to_retrieve) + assert topk_indices.shape == (xbdim, num_to_retrieve) # convert topk indices to indices in the dataset retrieved_indices = all_indices[topk_indices] - assert retrieved_indices.shape == (row_B, num_to_retrieve, 2) - - # pad the above to expected B - if row_B < B: - retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-row_B, num_to_retrieve, 2), dtype=int)], axis=0) - assert retrieved_indices.shape == (B, num_to_retrieve, 2) + assert retrieved_indices.shape == (xbdim, num_to_retrieve, 2) - myprint(f'Returning') return retrieved_indices \ No newline at end of file diff --git a/scripts_regent/eval_RandP.py b/scripts_regent/eval_RandP.py index 07e545c..146b347 100755 --- a/scripts_regent/eval_RandP.py +++ b/scripts_regent/eval_RandP.py @@ -15,9 +15,10 @@ from transformers import AutoModelForCausalLM, AutoProcessor, HfArgumentParser from jat.eval.rl import TASK_NAME_TO_ENV_ID, make from jat.utils import normalize, push_to_hub, save_video_grid -from jat_regent.RandP import RandP +from jat_regent.modeling_RandP import RandP from datasets import load_from_disk from datasets.config import HF_DATASETS_CACHE +from jat_regent.utils import myprint @dataclass @@ -70,6 +71,7 @@ def eval_rl(model, processor, task, eval_args): scores = [] frames = [] for episode in tqdm(range(eval_args.num_episodes), desc=task, unit="episode", leave=False): + myprint(('-'*100) + f'{episode=}') observation, _ = env.reset() reward = None rewards = [] @@ -96,6 +98,7 @@ def eval_rl(model, processor, task, eval_args): frames.append(np.array(env.render(), dtype=np.uint8)) scores.append(sum(rewards)) + myprint(('-'*100) + '\n\n\n') env.close() raw_mean, raw_std = np.mean(scores), np.std(scores) @@ -145,7 +148,9 @@ def main(): tasks.extend([env_id for env_id in TASK_NAME_TO_ENV_ID.keys() if env_id.startswith(domain)]) device = torch.device("cpu") if eval_args.use_cpu else get_default_device() - processor = None + processor = AutoProcessor.from_pretrained( + 'jat-project/jat', cache_dir=None, trust_remote_code=True + ) evaluations = {} video_list = [] @@ -153,14 +158,18 @@ def main(): for task in tqdm(tasks, desc="Evaluation", unit="task", leave=True): if task in TASK_NAME_TO_ENV_ID.keys(): + myprint(('-'*100) + f'{task=}') dataset = load_from_disk(f'{HF_DATASETS_CACHE}/jat-project/jat-dataset-tokenized/{task}') - model = RandP(dataset) + model = RandP(task, + dataset, + device,) scores, frames, fps = eval_rl(model, processor, task, eval_args) evaluations[task] = scores # Save the video if eval_args.save_video: video_list.append(frames) input_fps.append(fps) + myprint(('-'*100) + '\n\n\n') else: warnings.warn(f"Task {task} is not supported.") diff --git a/scripts_regent/offline_retrieval_jat_regent.py b/scripts_regent/offline_retrieval_jat_regent.py index c83d259..aad678a 100644 --- a/scripts_regent/offline_retrieval_jat_regent.py +++ b/scripts_regent/offline_retrieval_jat_regent.py @@ -8,7 +8,7 @@ import time from datetime import datetime from datasets import load_from_disk from datasets.config import HF_DATASETS_CACHE -from jat_regent.utils import myprint, process_row_atari, process_row_vector, retrieve_atari, retrieve_vector +from jat_regent.utils import myprint, get_task_info, collect_all_data, process_row_of_obs_atari_full_without_mask, retrieve_atari, retrieve_vector, collect_subset, build_index_vector import logging logging.basicConfig(level=logging.DEBUG) @@ -17,7 +17,8 @@ def main(): parser = argparse.ArgumentParser(description='Build RAAGENT sequence indices') parser.add_argument('--task', type=str, default='atari-alien', help='Task name') parser.add_argument('--num_to_retrieve', type=int, default=100, help='Number of states/windows to retrieve') - parser.add_argument('--nb_cores_autofaiss', type=int, default=16, help='Number of cores to use for faiss in vector observation environments') + parser.add_argument('--nb_cores_autofaiss', type=int, default=16, help='Number of cores to use for faiss in vector obs envs') + parser.add_argument('--batch_size_retrieval', type=int, default=1024, help='Batch size for retrieval in atari') args = parser.parse_args() # load dataset, map, device, for task @@ -25,77 +26,83 @@ def main(): dataset_path = f"{HF_DATASETS_CACHE}/jat-project/jat-dataset-tokenized/{task}" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - rew_key = 'rewards' - attn_key = 'attention_mask' - if task.startswith("atari"): - obs_key = 'image_observations' - act_key = 'discrete_actions' - len_row_tokenized_known = 32 # half of 54 - process_row_fn = process_row_atari - retrieve_fn = retrieve_atari - elif task.startswith("babyai"): - obs_key = 'discrete_observations'# also has 'text_observations' only for raw dataset not for tokenized dataset (as it is combined into discrete_observation in tokenized dataset) - act_key = 'discrete_actions' - len_row_tokenized_known = 256 # half of 512 - process_row_fn = lambda attn_mask, row_of_obs, task: process_row_vector(attn_mask, row_of_obs, task, return_numpy=True) - retrieve_fn = retrieve_vector - elif task.startswith("metaworld") or task.startswith("mujoco"): - obs_key = 'continuous_observations' - act_key = 'continuous_actions' - len_row_tokenized_known = 256 - process_row_fn = lambda attn_mask, row_of_obs, task: process_row_vector(attn_mask, row_of_obs, task, return_numpy=True) - retrieve_fn = retrieve_vector + rew_key, attn_key, obs_key, act_key, B, obs_dim = get_task_info(task) dataset = load_from_disk(dataset_path) with open(f"{dataset_path}/map_from_rows_to_episodes_for_tokenized.json", 'r') as f: map_from_rows_to_episodes_for_tokenized = json.load(f) # setup kwargs - len_dataset = len(dataset['train']) - B = len_row_tokenized_known kwargs = {'B': B, - 'attn_key':attn_key, - 'obs_key':obs_key, - 'device':device, - 'task':task, - 'batch_size_retrieval':None, - 'nb_cores_autofaiss':None if task.startswith("atari") else args.nb_cores_autofaiss, - } + 'obs_dim': obs_dim, + 'attn_key': attn_key, + 'obs_key': obs_key, + 'device': device, + 'task': task, + 'batch_size_retrieval': args.batch_size_retrieval, + 'nb_cores_autofaiss': None if task.startswith("atari") else args.nb_cores_autofaiss, + } # collect all observations in a single array (this takes some time) for vector observation environments - if not task.startswith("atari"): - myprint("Collecting all observations/attn_masks in a single array") - all_rows_of_obs = np.array(dataset['train'][obs_key]) - all_attn_masks = np.array(dataset['train'][attn_key]).astype(bool) + myprint("Collecting all observations/attn_masks") + all_rows_of_obs_OG, all_attn_masks_OG, all_row_idxs = collect_all_data(dataset, task, obs_key) # iterate over rows all_retrieved_indices = [] - for row_idx in range(len_dataset): - myprint(f"\nProcessing row {row_idx}/{len_dataset}") + for row_idx in all_row_idxs: + myprint(f"\nProcessing row {row_idx}/{len(all_row_idxs)}") current_ep = map_from_rows_to_episodes_for_tokenized[str(row_idx)] - attn_mask, row_of_obs = process_row_fn(dataset['train'][row_idx][attn_key], dataset['train'][row_idx][obs_key], task) + # get row_of_obs and attn_mask + datarow = dataset['train'][row_idx] + attn_mask = np.array(datarow[attn_key]).astype(bool) + if task.startswith("atari"): + row_of_obs = process_row_of_obs_atari_full_without_mask(datarow[obs_key]) + else: + row_of_obs = np.array(datarow[obs_key]) + row_of_obs = row_of_obs[attn_mask] + assert row_of_obs.shape == (np.sum(attn_mask), *obs_dim) # compare with rows from all but the current episode - all_other_rows = [idx for idx in range(len_dataset) if map_from_rows_to_episodes_for_tokenized[str(idx)] != current_ep] + all_other_row_idxs = [idx for idx in all_row_idxs if map_from_rows_to_episodes_for_tokenized[str(idx)] != current_ep] # do the retrieval - retrieved_indices = retrieve_fn(row_of_obs=row_of_obs, - dataset=dataset if task.startswith("atari") else (all_rows_of_obs, all_attn_masks), - all_rows_to_consider=all_other_rows, - num_to_retrieve=args.num_to_retrieve, - kwargs=kwargs, - ) + if task.startswith("atari"): + all_indices, all_processed_rows_of_obs = collect_subset(all_rows_of_obs_OG=all_rows_of_obs_OG, + all_attn_masks_OG=all_attn_masks_OG, + all_rows_to_consider=all_row_idxs, + kwargs=kwargs) + retrieved_indices = retrieve_atari(row_of_obs=row_of_obs, + all_processed_rows_of_obs=all_processed_rows_of_obs, + all_indices=all_indices, + num_to_retrieve=args.num_to_retrieve, + kwargs=kwargs) + else: + all_indices, knn_index = build_index_vector(all_rows_of_obs_OG=all_rows_of_obs_OG, + all_attn_masks_OG=all_attn_masks_OG, + all_rows_to_consider=all_other_row_idxs, + kwargs=kwargs) + retrieved_indices = retrieve_vector(row_of_obs=row_of_obs, + knn_index=knn_index, + all_indices=all_indices, + num_to_retrieve=args.num_to_retrieve, + kwargs=kwargs) + + # pad the above to expected B + xbdim = row_of_obs.shape[0] + if xbdim < B: + retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-xbdim, args.num_to_retrieve, 2), dtype=int)], axis=0) + assert retrieved_indices.shape == (B, args.num_to_retrieve, 2) # collect retrieved indices all_retrieved_indices.append(retrieved_indices) # concat all_retrieved_indices = np.stack(all_retrieved_indices, axis=0) - assert all_retrieved_indices.shape == (len_dataset, B, args.num_to_retrieve, 2) + assert all_retrieved_indices.shape == (len(all_row_idxs), B, args.num_to_retrieve, 2) # save arrays as bin for easy memmap access and faster loading - all_retrieved_indices.tofile(f"{dataset_path}/retrieved_indices_{len_dataset}_{B}_{args.num_to_retrieve}_2.bin") + all_retrieved_indices.tofile(f"{dataset_path}/retrieved_indices_{len(all_row_idxs)}_{B}_{args.num_to_retrieve}_2.bin") if __name__ == "__main__": main() \ No newline at end of file