diff --git a/unit-4/main.py b/unit-4/main.py index 347c250..834b615 100644 --- a/unit-4/main.py +++ b/unit-4/main.py @@ -69,7 +69,7 @@ class CartpolePolicy(nn.Module): class PixelcopterPolicy(nn.Module): def __init__(self, s_size, a_size, h_size, device): - super(Policy, self).__init__() + super(PixelcopterPolicy, self).__init__() self.fc1 = nn.Linear(s_size, h_size) self.fc2 = nn.Linear(h_size, h_size * 2) self.fc3 = nn.Linear(h_size * 2, a_size) @@ -170,8 +170,29 @@ def reinforce(policy, env, optimizer, n_training_episodes, max_t, gamma, print_e return scores - -def push_to_hub(repo_id, model, hyperparameters, eval_env, video_fps=30): +def record_video(env, policy, out_directory, fps=30): + """ + Generate a replay video of the agent + :param env + :param policy: policy of our agent; its act(state) picks each action + :param out_directory + :param fps: how many frames per second (with taxi-v3 and frozenlake-v1 we use 1) + """ + images = [] + done = False + state = env.reset() + img = env.render(mode="rgb_array") + images.append(img) + while not done: + # Take the action (index) that the policy returns for the current state + action, _ = policy.act(state) + state, reward, done, info = env.step(action) # We directly put next_state = state for recording logic + img = env.render(mode="rgb_array") + images.append(img) + imageio.mimsave(out_directory, [np.array(img) for i, img in enumerate(images)], fps=fps) + + +def push_to_hub(repo_id, model, hparams, eval_env, video_fps=30): """ Evaluate, Generate a video and Upload a model to Hugging Face Hub. 
This method does the complete pipeline: @@ -182,7 +203,7 @@ def push_to_hub(repo_id, model, hyperparameters, eval_env, video_fps=30): :param repo_id: repo_id: id of the model repository from the Hugging Face Hub :param model: the pytorch model we want to save - :param hyperparameters: training hyperparameters + :param hparams: training hyperparameters :param eval_env: evaluation environment :param video_fps: how many frame per seconds to record our video replay """ @@ -202,15 +223,15 @@ def push_to_hub(repo_id, model, hyperparameters, eval_env, video_fps=30): # Step 2: Save the model torch.save(model, local_directory / "model.pt") - # Step 3: Save the hyperparameters to JSON - with open(local_directory / "hyperparameters.json", "w") as outfile: - json.dump(hyperparameters, outfile) + # Step 3: Save the hyperparameters to JSON + with open(local_directory / "hparams.json", "w") as outfile: + json.dump(hparams, outfile) # Step 4: Evaluate the model and build JSON mean_reward, std_reward = evaluate_agent( eval_env, - hyperparameters["max_t"], - hyperparameters["n_evaluation_episodes"], + hparams["max_t"], + hparams["n_evaluation_episodes"], model, ) # Get datetime @@ -218,9 +239,9 @@ def push_to_hub(repo_id, model, hyperparameters, eval_env, video_fps=30): eval_form_datetime = eval_datetime.isoformat() evaluate_data = { - "env_id": hyperparameters["env_id"], + "env_id": hparams["env_id"], "mean_reward": mean_reward, - "n_evaluation_episodes": hyperparameters["n_evaluation_episodes"], + "n_evaluation_episodes": hparams["n_evaluation_episodes"], "eval_datetime": eval_form_datetime, } @@ -229,7 +250,7 @@ def push_to_hub(repo_id, model, hyperparameters, eval_env, video_fps=30): json.dump(evaluate_data, outfile) # Step 5: Create the model card - env_name = hyperparameters["env_id"] + env_name = hparams["env_id"] metadata = {} metadata["tags"] = [ @@ -256,8 +277,8 @@ def push_to_hub(repo_id, model, hyperparameters, eval_env, video_fps=30): metadata = {**metadata, **eval} model_card = f""" - 
# **Reinforce** Agent playing **{env_id}** - This is a trained model of a **Reinforce** agent playing **{env_id}** . + # **Reinforce** Agent playing **{env_name}** + This is a trained model of a **Reinforce** agent playing **{env_name}** . To learn to use this model and train yours check Unit 4 of the Deep Reinforcement Learning Course: https://huggingface.co/deep-rl-course/unit4/introduction """ @@ -277,7 +298,7 @@ def push_to_hub(repo_id, model, hyperparameters, eval_env, video_fps=30): # Step 6: Record a video video_path = local_directory / "replay.mp4" - record_video(env, model, video_path, video_fps) + record_video(eval_env, model, video_path, video_fps) # Step 7. Push everything to the Hub api.upload_folder(