diff --git a/data/envs/metaworld/generate_dataset.py b/data/envs/metaworld/generate_dataset.py index dc51b4f..b189348 100644 --- a/data/envs/metaworld/generate_dataset.py +++ b/data/envs/metaworld/generate_dataset.py @@ -4,7 +4,9 @@ from typing import Dict, Optional import gym import metaworld import numpy as np +import pandas as pd import torch +from datasets import Dataset from huggingface_hub import HfApi, repocard, upload_folder from sample_factory.algo.learning.learner import Learner from sample_factory.algo.sampling.batched_sampling import preprocess_actions @@ -12,11 +14,7 @@ from sample_factory.algo.utils.action_distributions import argmax_actions from sample_factory.algo.utils.env_info import extract_env_info from sample_factory.algo.utils.make_env import make_env_func_batched from sample_factory.algo.utils.rl_utils import make_dones, prepare_and_normalize_obs -from sample_factory.cfg.arguments import ( - load_from_checkpoint, - parse_full_cfg, - parse_sf_args, -) +from sample_factory.cfg.arguments import load_from_checkpoint, parse_full_cfg, parse_sf_args from sample_factory.envs.env_utils import register_env from sample_factory.model.actor_critic import create_actor_critic from sample_factory.model.model_utils import get_rnn_size @@ -206,6 +204,13 @@ def create_dataset(cfg: Config): env.close() + # Convert dict of numpy arrays to a Hugging Face Dataset + dataset = Dataset.from_dict(dataset) + dataset.create_config_id + # Set the card of the dataset + dataset.card = f"""""" + dataset.push_to_hub("qgallouedec/prj_gia_dataset_metaworld_assembly_v2_1111_demo") + # Save dataset repo_path = f"{cfg.train_dir}/datasets/{cfg.experiment}" os.makedirs(repo_path, exist_ok=True)