|
|
|
|
|
|
|
|
|
@@ -77,10 +77,16 @@ You can also save a video of the model during evaluation to upload to the hub wi |
|
|
|
- `--video_name`: The name of the video to save as. If `None`, will save to `replay.mp4` in your experiment directory |
|
|
|
+You can also include instructions in the Hugging Face Hub model card on how to train and enjoy (run) this model. These parameters are optional: |
|
+ |
|
+- `--train_script`: The module path for training this model |
|
+ |
|
+- `--enjoy_script`: The module path for enjoying this model |
|
+ |
|
For example: |
|
|
|
``` |
|
-python -m sf_examples.mujoco.enjoy_mujoco --algo=APPO --env=mujoco_ant --experiment=<repo_name> --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=<username>/<hf_repo_name> --save_video --no_render |
|
+python -m sf_examples.mujoco.enjoy_mujoco --algo=APPO --env=mujoco_ant --experiment=<repo_name> --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=<username>/<hf_repo_name> --save_video --no_render --enjoy_script=sf_examples.mujoco.enjoy_mujoco --train_script=sf_examples.mujoco.train_mujoco |
|
``` |
|
|
|
#### Using the push_to_hub Script |
|
@@ -95,4 +101,6 @@ The command line arguments are: |
|
|
|
- `-r`: The repo_id to save on HF Hub. This is the same as `hf_repository` in the enjoy script and must be in the form `<hf_username>/<hf_repo_name>` |
|
|
|
-- `-d`: The full path to your experiment directory to upload |
|
\ No newline at end of file |
|
+- `-d`: The full path to your experiment directory to upload |
|
+ |
|
+The optional arguments `--train_script` and `--enjoy_script` can also be used. See the section above for more details. |
|
\ No newline at end of file |
|
|
|
|
|
|
|
|
|
@@ -18,7 +18,7 @@ from sample_factory.cfg.cfg import ( |
|
) |
|
from sample_factory.utils.attr_dict import AttrDict |
|
from sample_factory.utils.typing import Config |
|
-from sample_factory.utils.utils import cfg_file, cfg_file_old, get_git_commit_hash, get_top_level_script, log |
|
+from sample_factory.utils.utils import cfg_file, cfg_file_old, get_git_commit_hash, log |
|
|
|
|
|
def parse_sf_args( |
|
@@ -91,7 +91,6 @@ def postprocess_args(args, argv, parser) -> argparse.Namespace: |
|
|
|
args.cli_args = vars(cli_args) |
|
args.git_hash, args.git_repo_name = get_git_commit_hash() |
|
- args.train_script = get_top_level_script() |
|
return args |
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -675,6 +675,19 @@ def add_eval_args(parser): |
|
help="False to sample from action distributions at test time. True to just use the argmax.", |
|
) |
|
|
|
+ parser.add_argument( |
|
+ "--train_script", |
|
+ default=None, |
|
+ type=str, |
|
+ help="Module name used to run training script. Used to generate HF model card", |
|
+ ) |
|
+ parser.add_argument( |
|
+ "--enjoy_script", |
|
+ default=None, |
|
+ type=str, |
|
+        help="Module name used to run enjoy script. Used to generate HF model card", |
|
+ ) |
|
+ |
|
|
|
def add_wandb_args(p: ArgumentParser): |
|
"""Weights and Biases experiment monitoring.""" |
|
|
|
|
|
|
|
|
|
@@ -21,7 +21,7 @@ from sample_factory.model.actor_critic import create_actor_critic |
|
from sample_factory.model.model_utils import get_rnn_size |
|
from sample_factory.utils.attr_dict import AttrDict |
|
from sample_factory.utils.typing import Config, StatusCode |
|
-from sample_factory.utils.utils import debug_log_every_n, experiment_dir, get_top_level_script, log |
|
+from sample_factory.utils.utils import debug_log_every_n, experiment_dir, log |
|
|
|
|
|
def visualize_policy_inputs(normalized_obs: Dict[str, Tensor]) -> None: |
|
@@ -260,9 +260,8 @@ def enjoy(cfg: Config) -> Tuple[StatusCode, float]: |
|
generate_replay_video(experiment_dir(cfg=cfg), video_frames, fps) |
|
|
|
if cfg.push_to_hub: |
|
- enjoy_name = get_top_level_script() |
|
generate_model_card( |
|
- experiment_dir(cfg=cfg), cfg.algo, cfg.env, cfg.hf_repository, reward_list, enjoy_name, cfg.train_script |
|
+ experiment_dir(cfg=cfg), cfg.algo, cfg.env, cfg.hf_repository, reward_list, cfg.enjoy_script, cfg.train_script |
|
) |
|
push_to_hf(experiment_dir(cfg=cfg), cfg.hf_repository) |
|
|
|
|
|
|
|
|
|
|
|
@@ -57,8 +57,10 @@ python -m sample_factory.huggingface.load_from_hub -r {repo_id} |
|
```\n |
|
""" |
|
|
|
- if enjoy_name is not None: |
|
- readme += f""" |
|
+ if enjoy_name is None: |
|
+ enjoy_name = "<path.to.enjoy.module>" |
|
+ |
|
+ readme += f""" |
|
## Using the model\n |
|
To run the model after download, use the `enjoy` script corresponding to this environment: |
|
``` |
|
@@ -67,17 +69,19 @@ python -m {enjoy_name} --algo={algo} --env={env} --train_dir=./train_dir --exper |
|
\n |
|
You can also upload models to the Hugging Face Hub using the same script with the `--push_to_hub` flag. |
|
See https://www.samplefactory.dev/10-huggingface/huggingface/ for more details |
|
- """ |
|
+ """ |
|
|
|
- if train_name is not None: |
|
- readme += f""" |
|
+ if train_name is None: |
|
+ train_name = "<path.to.train.module>" |
|
+ |
|
+ readme += f""" |
|
## Training with this model\n |
|
To continue training with this model, use the `train` script corresponding to this environment: |
|
``` |
|
python -m {train_name} --algo={algo} --env={env} --train_dir=./train_dir --experiment={repo_name} --restart_behavior=resume --train_for_env_steps=10000000000 |
|
```\n |
|
Note, you may have to adjust `--train_for_env_steps` to a suitably high number as the experiment will resume at the number of steps it concluded at. |
|
- """ |
|
+ """ |
|
|
|
with open(readme_path, "w", encoding="utf-8") as f: |
|
f.write(readme) |
|
|
|
|
|
|
|
|
|
@@ -16,6 +16,18 @@ def main(): |
|
type=str, |
|
) |
|
parser.add_argument("-d", "--experiment_dir", help="Path to your experiment directory", type=str) |
|
+ parser.add_argument( |
|
+ "--train_script", |
|
+ default=None, |
|
+ type=str, |
|
+ help="Module name used to run training script. Used to generate HF model card", |
|
+ ) |
|
+ parser.add_argument( |
|
+ "--enjoy_script", |
|
+ default=None, |
|
+ type=str, |
|
+        help="Module name used to run enjoy script. Used to generate HF model card", |
|
+ ) |
|
args = parser.parse_args() |
|
|
|
cfg_file = os.path.join(args.experiment_dir, "config.json") |
|
@@ -34,7 +46,7 @@ def main(): |
|
json_params = json.load(json_file) |
|
cfg = AttrDict(json_params) |
|
|
|
- generate_model_card(args.experiment_dir, cfg.algo, cfg.env, args.hf_repository) |
|
+ generate_model_card(args.experiment_dir, cfg.algo, cfg.env, args.hf_repository, enjoy_name=args.enjoy_script, train_name=args.train_script) |
|
push_to_hf(args.experiment_dir, args.hf_repository) |
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -493,5 +493,5 @@ def debug_log_every_n(n, msg, *args, **kwargs): |
|
log_every_n(n, logging.DEBUG, msg, *args, **kwargs) |
|
|
|
|
|
-def get_top_level_script(): |
|
- return argv[0].split("sample-factory/")[-1][:-3].replace("/", ".") |
|
+# def get_top_level_script(): |
|
+# return argv[0].split("sample-factory/")[-1][:-3].replace("/", ".") |
|
|