# Retrieved from the Hugging Face Hub file viewer (uploader: jostyposty).
# Last commit: "fix: rename video to replay.mp4" (f88882e); file size 1.62 kB.
from huggingface_sb3.push_to_hub import generate_metadata
from huggingface_hub.repocard import metadata_save
from hf_helpers.gym_video import generate_video, load_ppo_model_for_video
from hf_helpers.hf_sb3 import generate_config_json, generate_results_json
from hf_helpers.sb3_eval import eval_model_with_seed
# Paths and identifiers shared by the steps of this script.
readme_path = "README.md"
env_id = "LunarLander-v2"

# The checkpoint that is evaluated, recorded on video and named in the model card.
main_model_fp = "ppo-LunarLander-v2_010_000_000_hf_defaults.zip"

# Additional checkpoints; in this script they only get their own config file.
other_models = [
    "ppo-LunarLander-v2_001_000_000_hf_defaults.zip",
    "ppo-LunarLander-v2_010_000_000_sb3_defaults.zip",
    "ppo-LunarLander-v2_123_456_789_hf_defaults.zip",
]
# 1. Evaluate model
# Fixed evaluation settings; n_eval_episodes is reused when results.json is written.
best_seed, best_n_envs = 902, 8
n_eval_episodes = 10

evaluation = eval_model_with_seed(
    main_model_fp,
    env_id,
    seed=best_seed,
    n_eval_episodes=n_eval_episodes,
    n_envs=best_n_envs,
)
# The helper returns three values; only the mean and std are used by later steps.
result, mean_reward, std_reward = evaluation
# 2. Create config.json
generate_config_json(main_model_fp, "config.json")

# Also create config files for the other models, named "config-<checkpoint name>.json"
for model_fp in other_models:
    config_name = "config-" + model_fp.replace(".zip", "") + ".json"
    generate_config_json(model_fp, config_name)
# 3. Create results.json
generate_results_json(
    "results.json",
    mean_reward,
    std_reward,
    n_eval_episodes,
    True,  # NOTE(review): positional flag; its meaning is defined by generate_results_json — confirm
)
# 4. Generate video
# Load the main checkpoint through the video helper, then record to replay.mp4.
model_for_video = load_ppo_model_for_video(main_model_fp, env_id)
generate_video(
    model_for_video,
    "replay.mp4",
    video_length_in_episodes=5,
)
# 5. Generate model card
# The model name is the main checkpoint's file name with ".zip" removed.
card_model_name = main_model_fp.replace(".zip", "")
metadata = generate_metadata(
    model_name=card_model_name,
    env_id=env_id,
    mean_reward=mean_reward,
    std_reward=std_reward,
)
# Add the license entry before the metadata is saved to the README.
metadata["license"] = "mit"
metadata_save(readme_path, metadata)