File size: 8,119 Bytes
4cffbfd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
diff --git a/docs/10-huggingface/huggingface.md b/docs/10-huggingface/huggingface.md
index 8846da73..1f1fae6f 100644
--- a/docs/10-huggingface/huggingface.md
+++ b/docs/10-huggingface/huggingface.md
@@ -77,10 +77,16 @@ You can also save a video of the model during evaluation to upload to the hub wi
- `--video_name`: The name of the video to save as. If `None`, will save to `replay.mp4` in your experiment directory
+Also, you can include information in the Hugging Face Hub model card for how to train and enjoy using this model. These parameters are optional:
+
+- `--train_script`: The module path for training this model
+
+- `--enjoy_script`: The module path for enjoying this model
+
For example:
```
-python -m sf_examples.mujoco.enjoy_mujoco --algo=APPO --env=mujoco_ant --experiment=<repo_name> --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=<username>/<hf_repo_name> --save_video --no_render
+python -m sf_examples.mujoco.enjoy_mujoco --algo=APPO --env=mujoco_ant --experiment=<repo_name> --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=<username>/<hf_repo_name> --save_video --no_render --enjoy_script=sf_examples.mujoco.enjoy_mujoco --train_script=sf_examples.mujoco.train_mujoco
```
#### Using the push_to_hub Script
@@ -95,4 +101,6 @@ The command line arguments are:
- `-r`: The repo_id to save on HF Hub. This is the same as `hf_repository` in the enjoy script and must be in the form `<hf_username>/<hf_repo_name>`
-- `-d`: The full path to your experiment directory to upload
\ No newline at end of file
+- `-d`: The full path to your experiment directory to upload
+
+The optional arguments of `--train_script` and `--enjoy_script` can also be used. See the above section for more details
\ No newline at end of file
diff --git a/sample_factory/cfg/arguments.py b/sample_factory/cfg/arguments.py
index 820efce6..f736342d 100644
--- a/sample_factory/cfg/arguments.py
+++ b/sample_factory/cfg/arguments.py
@@ -18,7 +18,7 @@ from sample_factory.cfg.cfg import (
)
from sample_factory.utils.attr_dict import AttrDict
from sample_factory.utils.typing import Config
-from sample_factory.utils.utils import cfg_file, cfg_file_old, get_git_commit_hash, get_top_level_script, log
+from sample_factory.utils.utils import cfg_file, cfg_file_old, get_git_commit_hash, log
def parse_sf_args(
@@ -91,7 +91,6 @@ def postprocess_args(args, argv, parser) -> argparse.Namespace:
args.cli_args = vars(cli_args)
args.git_hash, args.git_repo_name = get_git_commit_hash()
- args.train_script = get_top_level_script()
return args
diff --git a/sample_factory/cfg/cfg.py b/sample_factory/cfg/cfg.py
index 43393da1..360e6895 100644
--- a/sample_factory/cfg/cfg.py
+++ b/sample_factory/cfg/cfg.py
@@ -675,6 +675,19 @@ def add_eval_args(parser):
help="False to sample from action distributions at test time. True to just use the argmax.",
)
+ parser.add_argument(
+ "--train_script",
+ default=None,
+ type=str,
+ help="Module name used to run training script. Used to generate HF model card",
+ )
+ parser.add_argument(
+ "--enjoy_script",
+ default=None,
+ type=str,
+ help="Module name used to run training script. Used to generate HF model card",
+ )
+
def add_wandb_args(p: ArgumentParser):
"""Weights and Biases experiment monitoring."""
diff --git a/sample_factory/enjoy.py b/sample_factory/enjoy.py
index 341b537b..b620c532 100644
--- a/sample_factory/enjoy.py
+++ b/sample_factory/enjoy.py
@@ -21,7 +21,7 @@ from sample_factory.model.actor_critic import create_actor_critic
from sample_factory.model.model_utils import get_rnn_size
from sample_factory.utils.attr_dict import AttrDict
from sample_factory.utils.typing import Config, StatusCode
-from sample_factory.utils.utils import debug_log_every_n, experiment_dir, get_top_level_script, log
+from sample_factory.utils.utils import debug_log_every_n, experiment_dir, log
def visualize_policy_inputs(normalized_obs: Dict[str, Tensor]) -> None:
@@ -260,9 +260,8 @@ def enjoy(cfg: Config) -> Tuple[StatusCode, float]:
generate_replay_video(experiment_dir(cfg=cfg), video_frames, fps)
if cfg.push_to_hub:
- enjoy_name = get_top_level_script()
generate_model_card(
- experiment_dir(cfg=cfg), cfg.algo, cfg.env, cfg.hf_repository, reward_list, enjoy_name, cfg.train_script
+ experiment_dir(cfg=cfg), cfg.algo, cfg.env, cfg.hf_repository, reward_list, cfg.enjoy_script, cfg.train_script
)
push_to_hf(experiment_dir(cfg=cfg), cfg.hf_repository)
diff --git a/sample_factory/huggingface/huggingface_utils.py b/sample_factory/huggingface/huggingface_utils.py
index 90184da7..5b4a6b14 100644
--- a/sample_factory/huggingface/huggingface_utils.py
+++ b/sample_factory/huggingface/huggingface_utils.py
@@ -57,8 +57,10 @@ python -m sample_factory.huggingface.load_from_hub -r {repo_id}
```\n
"""
- if enjoy_name is not None:
- readme += f"""
+ if enjoy_name is None:
+ enjoy_name = "<path.to.enjoy.module>"
+
+ readme += f"""
## Using the model\n
To run the model after download, use the `enjoy` script corresponding to this environment:
```
@@ -67,17 +69,19 @@ python -m {enjoy_name} --algo={algo} --env={env} --train_dir=./train_dir --exper
\n
You can also upload models to the Hugging Face Hub using the same script with the `--push_to_hub` flag.
See https://www.samplefactory.dev/10-huggingface/huggingface/ for more details
- """
+ """
- if train_name is not None:
- readme += f"""
+ if train_name is None:
+ train_name = "<path.to.train.module>"
+
+ readme += f"""
## Training with this model\n
To continue training with this model, use the `train` script corresponding to this environment:
```
python -m {train_name} --algo={algo} --env={env} --train_dir=./train_dir --experiment={repo_name} --restart_behavior=resume --train_for_env_steps=10000000000
```\n
Note, you may have to adjust `--train_for_env_steps` to a suitably high number as the experiment will resume at the number of steps it concluded at.
- """
+ """
with open(readme_path, "w", encoding="utf-8") as f:
f.write(readme)
diff --git a/sample_factory/huggingface/push_to_hub.py b/sample_factory/huggingface/push_to_hub.py
index dbd5c382..d67806ad 100644
--- a/sample_factory/huggingface/push_to_hub.py
+++ b/sample_factory/huggingface/push_to_hub.py
@@ -16,6 +16,18 @@ def main():
type=str,
)
parser.add_argument("-d", "--experiment_dir", help="Path to your experiment directory", type=str)
+ parser.add_argument(
+ "--train_script",
+ default=None,
+ type=str,
+ help="Module name used to run training script. Used to generate HF model card",
+ )
+ parser.add_argument(
+ "--enjoy_script",
+ default=None,
+ type=str,
+ help="Module name used to run training script. Used to generate HF model card",
+ )
args = parser.parse_args()
cfg_file = os.path.join(args.experiment_dir, "config.json")
@@ -34,7 +46,7 @@ def main():
json_params = json.load(json_file)
cfg = AttrDict(json_params)
- generate_model_card(args.experiment_dir, cfg.algo, cfg.env, args.hf_repository)
+ generate_model_card(args.experiment_dir, cfg.algo, cfg.env, args.hf_repository, enjoy_name=args.enjoy_script, train_name=args.train_script)
push_to_hf(args.experiment_dir, args.hf_repository)
diff --git a/sample_factory/utils/utils.py b/sample_factory/utils/utils.py
index 99db3c10..fcd335c5 100644
--- a/sample_factory/utils/utils.py
+++ b/sample_factory/utils/utils.py
@@ -493,5 +493,5 @@ def debug_log_every_n(n, msg, *args, **kwargs):
log_every_n(n, logging.DEBUG, msg, *args, **kwargs)
-def get_top_level_script():
- return argv[0].split("sample-factory/")[-1][:-3].replace("/", ".")
+# def get_top_level_script():
+# return argv[0].split("sample-factory/")[-1][:-3].replace("/", ".")
|