diff --git a/data/envs/metaworld/generate_dataset.py b/data/envs/metaworld/generate_dataset.py
index dc51b4f..ed99a4a 100644
--- a/data/envs/metaworld/generate_dataset.py
+++ b/data/envs/metaworld/generate_dataset.py
@@ -4,7 +4,9 @@ from typing import Dict, Optional
 import gym
 import metaworld
 import numpy as np
+import pandas as pd
 import torch
+from datasets import Dataset
 from huggingface_hub import HfApi, repocard, upload_folder
 from sample_factory.algo.learning.learner import Learner
 from sample_factory.algo.sampling.batched_sampling import preprocess_actions
@@ -12,11 +14,7 @@ from sample_factory.algo.utils.action_distributions import argmax_actions
 from sample_factory.algo.utils.env_info import extract_env_info
 from sample_factory.algo.utils.make_env import make_env_func_batched
 from sample_factory.algo.utils.rl_utils import make_dones, prepare_and_normalize_obs
-from sample_factory.cfg.arguments import (
-    load_from_checkpoint,
-    parse_full_cfg,
-    parse_sf_args,
-)
+from sample_factory.cfg.arguments import load_from_checkpoint, parse_full_cfg, parse_sf_args
 from sample_factory.envs.env_utils import register_env
 from sample_factory.model.actor_critic import create_actor_critic
 from sample_factory.model.model_utils import get_rnn_size
@@ -165,10 +163,8 @@ def create_dataset(cfg: Config):
     # Create dataset
     dataset_size = 100_000
     dataset = {
-        "observations": np.zeros(
-            (dataset_size, *env.observation_space["obs"].shape), dtype=env.observation_space["obs"].dtype
-        ),
-        "actions": np.zeros((dataset_size, *env.action_space.shape), env.action_space.dtype),
+        "observations": np.zeros((dataset_size, *env.observation_space["obs"].shape), dtype=np.float32),
+        "actions": np.zeros((dataset_size, *env.action_space.shape), np.float32),
         "dones": np.zeros((dataset_size,), bool),
         "rewards": np.zeros((dataset_size,), np.float32),
     }
@@ -206,6 +202,13 @@ def create_dataset(cfg: Config):
 
     env.close()
 
+    # Convert dict of numpy array to pandas dataframe
+#     dataset = Dataset.from_dict(dataset)
+#     dataset.create_config_id
+    # Set the card of the dataset
+#     dataset.card = f""""""
+#     dataset.push_to_hub("qgallouedec/prj_gia_dataset_metaworld_assembly_v2_1111_demo")
+
     # Save dataset
     repo_path = f"{cfg.train_dir}/datasets/{cfg.experiment}"
     os.makedirs(repo_path, exist_ok=True)
diff --git a/data/envs/metaworld/generate_dataset_all.sh b/data/envs/metaworld/generate_dataset_all.sh
index 802bf5c..3cc4f97 100755
--- a/data/envs/metaworld/generate_dataset_all.sh
+++ b/data/envs/metaworld/generate_dataset_all.sh
@@ -1,34 +1,6 @@
 #!/bin/bash
 
 ENVS=(
-    assembly
-    basketball
-    bin-picking
-    box-close
-    button-press-topdown
-    button-press-topdown-wall
-    button-press
-    button-press-wall
-    coffee-button
-    coffee-pull
-    coffee-push
-    dial-turn
-    disassemble
-    door-close
-    door-lock
-    door-open
-    door-unlock
-    drawer-close
-    drawer-open
-    faucet-close
-    faucet-open
-    hammer
-    hand-insert
-    handle-press-side
-    handle-press
-    handle-pull-side
-    handle-pull
-    lever-pull
     peg-insert-side
     peg-unplug-side
     pick-out-of-hole
@@ -40,19 +12,8 @@ ENVS=(
     plate-slide
     push-back
     push
-    push-wall
-    reach
-    reach-wall
-    shelf-place
-    soccer
-    stick-pull
-    stick-push
-    sweep-into
-    sweep
-    window-close
-    window-open
 )
 
 for ENV in "${ENVS[@]}"; do
-    python generate_dataset.py --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir
+    python generate_dataset.py --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir
 done
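
Note on the commented-out block in generate_dataset.py: its first comment says "Convert dict of numpy array to pandas dataframe", but the code it precedes builds a datasets.Dataset via Dataset.from_dict (the new pandas import is never used in this hunk). A minimal sketch of what that upload path could look like once enabled, assuming the dataset dict of numpy arrays built in create_dataset and write access to the Hub repo named in the commented-out code; push_dataset_to_hub is a hypothetical helper, not part of the commit:

    from datasets import Dataset

    def push_dataset_to_hub(dataset: dict, repo_id: str) -> None:
        # Hypothetical helper, not part of the commit. Each key
        # ("observations", "actions", "dones", "rewards") becomes a column;
        # from_dict accepts numpy arrays directly, so no pandas round-trip
        # is needed.
        hf_dataset = Dataset.from_dict(dataset)
        # Uploads the dataset to the Hub; requires prior authentication
        # (e.g. `huggingface-cli login`).
        hf_dataset.push_to_hub(repo_id)

    # Example call with the repo id from the commented-out code:
    # push_dataset_to_hub(dataset, "qgallouedec/prj_gia_dataset_metaworld_assembly_v2_1111_demo")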