File size: 8,119 Bytes
4cffbfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
diff --git a/docs/10-huggingface/huggingface.md b/docs/10-huggingface/huggingface.md
index 8846da73..1f1fae6f 100644
--- a/docs/10-huggingface/huggingface.md
+++ b/docs/10-huggingface/huggingface.md
@@ -77,10 +77,16 @@ You can also save a video of the model during evaluation to upload to the hub wi
 
 - `--video_name`: The name of the video to save as. If `None`, will save to `replay.mp4` in your experiment directory
 
+Also, you can include information in the Hugging Face Hub model card for how to train and enjoy using this model. These parameters are optional:
+
+- `--train_script`: The module path for training this model
+
+- `--enjoy_script`: The module path for enjoying this model
+
 For example:
 
 ```
-python -m sf_examples.mujoco.enjoy_mujoco --algo=APPO --env=mujoco_ant --experiment=<repo_name> --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=<username>/<hf_repo_name> --save_video --no_render
+python -m sf_examples.mujoco.enjoy_mujoco --algo=APPO --env=mujoco_ant --experiment=<repo_name> --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=<username>/<hf_repo_name> --save_video --no_render --enjoy_script=sf_examples.mujoco.enjoy_mujoco --train_script=sf_examples.mujoco.train_mujoco
 ```
 
 #### Using the push_to_hub Script
@@ -95,4 +101,6 @@ The command line arguments are:
 
 - `-r`: The repo_id to save on HF Hub. This is the same as `hf_repository` in the enjoy script and must be in the form `<hf_username>/<hf_repo_name>`
 
-- `-d`: The full path to your experiment directory to upload
\ No newline at end of file
+- `-d`: The full path to your experiment directory to upload
+
+The optional arguments of `--train_script` and `--enjoy_script` can also be used. See the above section for more details
\ No newline at end of file
diff --git a/sample_factory/cfg/arguments.py b/sample_factory/cfg/arguments.py
index 820efce6..f736342d 100644
--- a/sample_factory/cfg/arguments.py
+++ b/sample_factory/cfg/arguments.py
@@ -18,7 +18,7 @@ from sample_factory.cfg.cfg import (
 )
 from sample_factory.utils.attr_dict import AttrDict
 from sample_factory.utils.typing import Config
-from sample_factory.utils.utils import cfg_file, cfg_file_old, get_git_commit_hash, get_top_level_script, log
+from sample_factory.utils.utils import cfg_file, cfg_file_old, get_git_commit_hash, log
 
 
 def parse_sf_args(
@@ -91,7 +91,6 @@ def postprocess_args(args, argv, parser) -> argparse.Namespace:
 
     args.cli_args = vars(cli_args)
     args.git_hash, args.git_repo_name = get_git_commit_hash()
-    args.train_script = get_top_level_script()
     return args
 
 
diff --git a/sample_factory/cfg/cfg.py b/sample_factory/cfg/cfg.py
index 43393da1..360e6895 100644
--- a/sample_factory/cfg/cfg.py
+++ b/sample_factory/cfg/cfg.py
@@ -675,6 +675,19 @@ def add_eval_args(parser):
         help="False to sample from action distributions at test time. True to just use the argmax.",
     )
 
+    parser.add_argument(
+        "--train_script",
+        default=None,
+        type=str,
+        help="Module name used to run training script. Used to generate HF model card",
+    )
+    parser.add_argument(
+        "--enjoy_script",
+        default=None,
+        type=str,
+        help="Module name used to run training script. Used to generate HF model card",
+    )
+
 
 def add_wandb_args(p: ArgumentParser):
     """Weights and Biases experiment monitoring."""
diff --git a/sample_factory/enjoy.py b/sample_factory/enjoy.py
index 341b537b..b620c532 100644
--- a/sample_factory/enjoy.py
+++ b/sample_factory/enjoy.py
@@ -21,7 +21,7 @@ from sample_factory.model.actor_critic import create_actor_critic
 from sample_factory.model.model_utils import get_rnn_size
 from sample_factory.utils.attr_dict import AttrDict
 from sample_factory.utils.typing import Config, StatusCode
-from sample_factory.utils.utils import debug_log_every_n, experiment_dir, get_top_level_script, log
+from sample_factory.utils.utils import debug_log_every_n, experiment_dir, log
 
 
 def visualize_policy_inputs(normalized_obs: Dict[str, Tensor]) -> None:
@@ -260,9 +260,8 @@ def enjoy(cfg: Config) -> Tuple[StatusCode, float]:
         generate_replay_video(experiment_dir(cfg=cfg), video_frames, fps)
 
     if cfg.push_to_hub:
-        enjoy_name = get_top_level_script()
         generate_model_card(
-            experiment_dir(cfg=cfg), cfg.algo, cfg.env, cfg.hf_repository, reward_list, enjoy_name, cfg.train_script
+            experiment_dir(cfg=cfg), cfg.algo, cfg.env, cfg.hf_repository, reward_list, cfg.enjoy_script, cfg.train_script
         )
         push_to_hf(experiment_dir(cfg=cfg), cfg.hf_repository)
 
diff --git a/sample_factory/huggingface/huggingface_utils.py b/sample_factory/huggingface/huggingface_utils.py
index 90184da7..5b4a6b14 100644
--- a/sample_factory/huggingface/huggingface_utils.py
+++ b/sample_factory/huggingface/huggingface_utils.py
@@ -57,8 +57,10 @@ python -m sample_factory.huggingface.load_from_hub -r {repo_id}
 ```\n
     """
 
-    if enjoy_name is not None:
-        readme += f"""
+    if enjoy_name is None:
+        enjoy_name = "<path.to.enjoy.module>"
+        
+    readme += f"""
 ## Using the model\n
 To run the model after download, use the `enjoy` script corresponding to this environment:
 ```
@@ -67,17 +69,19 @@ python -m {enjoy_name} --algo={algo} --env={env} --train_dir=./train_dir --exper
 \n
 You can also upload models to the Hugging Face Hub using the same script with the `--push_to_hub` flag.
 See https://www.samplefactory.dev/10-huggingface/huggingface/ for more details
-        """
+    """
 
-    if train_name is not None:
-        readme += f"""
+    if train_name is None:
+        train_name = "<path.to.train.module>"
+        
+    readme += f"""
 ## Training with this model\n
 To continue training with this model, use the `train` script corresponding to this environment:
 ```
 python -m {train_name} --algo={algo} --env={env} --train_dir=./train_dir --experiment={repo_name} --restart_behavior=resume --train_for_env_steps=10000000000
 ```\n
 Note, you may have to adjust `--train_for_env_steps` to a suitably high number as the experiment will resume at the number of steps it concluded at.
-        """
+    """
 
     with open(readme_path, "w", encoding="utf-8") as f:
         f.write(readme)
diff --git a/sample_factory/huggingface/push_to_hub.py b/sample_factory/huggingface/push_to_hub.py
index dbd5c382..d67806ad 100644
--- a/sample_factory/huggingface/push_to_hub.py
+++ b/sample_factory/huggingface/push_to_hub.py
@@ -16,6 +16,18 @@ def main():
         type=str,
     )
     parser.add_argument("-d", "--experiment_dir", help="Path to your experiment directory", type=str)
+    parser.add_argument(
+        "--train_script",
+        default=None,
+        type=str,
+        help="Module name used to run training script. Used to generate HF model card",
+    )
+    parser.add_argument(
+        "--enjoy_script",
+        default=None,
+        type=str,
+        help="Module name used to run training script. Used to generate HF model card",
+    )
     args = parser.parse_args()
 
     cfg_file = os.path.join(args.experiment_dir, "config.json")
@@ -34,7 +46,7 @@ def main():
         json_params = json.load(json_file)
         cfg = AttrDict(json_params)
 
-    generate_model_card(args.experiment_dir, cfg.algo, cfg.env, args.hf_repository)
+    generate_model_card(args.experiment_dir, cfg.algo, cfg.env, args.hf_repository, enjoy_name=args.enjoy_script, train_name=args.train_script)
     push_to_hf(args.experiment_dir, args.hf_repository)
 
 
diff --git a/sample_factory/utils/utils.py b/sample_factory/utils/utils.py
index 99db3c10..fcd335c5 100644
--- a/sample_factory/utils/utils.py
+++ b/sample_factory/utils/utils.py
@@ -493,5 +493,5 @@ def debug_log_every_n(n, msg, *args, **kwargs):
     log_every_n(n, logging.DEBUG, msg, *args, **kwargs)
 
 
-def get_top_level_script():
-    return argv[0].split("sample-factory/")[-1][:-3].replace("/", ".")
+# def get_top_level_script():
+#     return argv[0].split("sample-factory/")[-1][:-3].replace("/", ".")