qgallouedec's picture
qgallouedec HF staff
Upload . with huggingface_hub
b20cb3b
raw
history blame contribute delete
No virus
3.74 kB
diff --git a/data/envs/metaworld/generate_dataset.py b/data/envs/metaworld/generate_dataset.py
index dc51b4f..ed99a4a 100644
--- a/data/envs/metaworld/generate_dataset.py
+++ b/data/envs/metaworld/generate_dataset.py
@@ -4,7 +4,9 @@ from typing import Dict, Optional
import gym
import metaworld
import numpy as np
+import pandas as pd
import torch
+from datasets import Dataset
from huggingface_hub import HfApi, repocard, upload_folder
from sample_factory.algo.learning.learner import Learner
from sample_factory.algo.sampling.batched_sampling import preprocess_actions
@@ -12,11 +14,7 @@ from sample_factory.algo.utils.action_distributions import argmax_actions
from sample_factory.algo.utils.env_info import extract_env_info
from sample_factory.algo.utils.make_env import make_env_func_batched
from sample_factory.algo.utils.rl_utils import make_dones, prepare_and_normalize_obs
-from sample_factory.cfg.arguments import (
- load_from_checkpoint,
- parse_full_cfg,
- parse_sf_args,
-)
+from sample_factory.cfg.arguments import load_from_checkpoint, parse_full_cfg, parse_sf_args
from sample_factory.envs.env_utils import register_env
from sample_factory.model.actor_critic import create_actor_critic
from sample_factory.model.model_utils import get_rnn_size
@@ -165,10 +163,8 @@ def create_dataset(cfg: Config):
# Create dataset
dataset_size = 100_000
dataset = {
- "observations": np.zeros(
- (dataset_size, *env.observation_space["obs"].shape), dtype=env.observation_space["obs"].dtype
- ),
- "actions": np.zeros((dataset_size, *env.action_space.shape), env.action_space.dtype),
+ "observations": np.zeros((dataset_size, *env.observation_space["obs"].shape), dtype=np.float32),
+ "actions": np.zeros((dataset_size, *env.action_space.shape), np.float32),
"dones": np.zeros((dataset_size,), bool),
"rewards": np.zeros((dataset_size,), np.float32),
}
@@ -206,6 +202,13 @@ def create_dataset(cfg: Config):
env.close()
+ # Convert dict of numpy array to pandas dataframe
+# dataset = Dataset.from_dict(dataset)
+# dataset.create_config_id
+ # Set the card of the dataset
+# dataset.card = f""""""
+# dataset.push_to_hub("qgallouedec/prj_gia_dataset_metaworld_assembly_v2_1111_demo")
+
# Save dataset
repo_path = f"{cfg.train_dir}/datasets/{cfg.experiment}"
os.makedirs(repo_path, exist_ok=True)
diff --git a/data/envs/metaworld/generate_dataset_all.sh b/data/envs/metaworld/generate_dataset_all.sh
index 802bf5c..3cc4f97 100755
--- a/data/envs/metaworld/generate_dataset_all.sh
+++ b/data/envs/metaworld/generate_dataset_all.sh
@@ -1,34 +1,6 @@
#!/bin/bash
ENVS=(
- assembly
- basketball
- bin-picking
- box-close
- button-press-topdown
- button-press-topdown-wall
- button-press
- button-press-wall
- coffee-button
- coffee-pull
- coffee-push
- dial-turn
- disassemble
- door-close
- door-lock
- door-open
- door-unlock
- drawer-close
- drawer-open
- faucet-close
- faucet-open
- hammer
- hand-insert
- handle-press-side
- handle-press
- handle-pull-side
- handle-pull
- lever-pull
peg-insert-side
peg-unplug-side
pick-out-of-hole
@@ -40,19 +12,8 @@ ENVS=(
plate-slide
push-back
push
- push-wall
- reach
- reach-wall
- shelf-place
- soccer
- stick-pull
- stick-push
- sweep-into
- sweep
- window-close
- window-open
)
for ENV in "${ENVS[@]}"; do
- python generate_dataset.py --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir
+ python generate_dataset.py --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir
done