{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "toc_visible": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU", "gpuClass": "standard", "widgets": { "application/vnd.jupyter.widget-state+json": { "17e75f6149c14f8d9619f319bf9ee553": { "model_module": "@jupyter-widgets/output", "model_name": "OutputModel", "model_module_version": "1.0.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/output", "_model_module_version": "1.0.0", "_model_name": "OutputModel", "_view_count": null, "_view_module": "@jupyter-widgets/output", "_view_module_version": "1.0.0", "_view_name": "OutputView", "layout": "IPY_MODEL_cbb6deb9baf742e8a6765d0891f193c7", "msg_id": "", "outputs": [ { "output_type": "display_data", "data": { "text/plain": "\u001b[35m 93%\u001b[0m \u001b[38;2;249;38;114m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[38;2;249;38;114m╸\u001b[0m\u001b[38;5;237m━━━━\u001b[0m \u001b[32m9,314/10,000 \u001b[0m [ \u001b[33m0:00:02\u001b[0m < \u001b[36m0:00:01\u001b[0m , \u001b[31m3,340 it/s\u001b[0m ]\n", "text/html": "
93% ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸━━━━ 9,314/10,000 [ 0:00:02 < 0:00:01 , 3,340 it/s ]\n\n" }, "metadata": {} } ] } }, "cbb6deb9baf742e8a6765d0891f193c7": { "model_module": "@jupyter-widgets/base", "model_name": "LayoutModel", "model_module_version": "1.2.0", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "6314513b30fd4f11ad8b59b0bcdee8d8": { "model_module": "@jupyter-widgets/controls", "model_name": "VBoxModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "VBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "VBoxView", "box_style": "", "children": [ "IPY_MODEL_1966d341666441089910ba16f7ac169c", "IPY_MODEL_965ba6c8ce3d4c5bbe14c9a7192a221c", "IPY_MODEL_6f5377c55bfe4428805ed94034e68d6b", "IPY_MODEL_2d986d4e45d44a2b858390514b770a9d", "IPY_MODEL_3c532f73534240b19c47e2b943418e93" ], "layout": "IPY_MODEL_ac189f596cb343bd9a6ffcf1b158f56b" } }, "1966d341666441089910ba16f7ac169c": { "model_module": "@jupyter-widgets/controls", "model_name": "HTMLModel", "model_module_version": "1.5.0", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_7c2c6b532db14e9bafcdb0beeb1c33bc", "placeholder": "", "style": "IPY_MODEL_cd4b78c526cf4df88744348e711754da", "value": "
\n", "\n" ] }, "metadata": {} }, { "output_type": "stream", "name": "stdout", "text": [ "mean_reward:-147.89 +/- 27.79\n" ] } ] }, { "cell_type": "code", "source": [ "# Enjoy trained agent\n", "vec_env = model.get_env()\n", "obs = vec_env.reset()\n", "for i in range(1000):\n", " action, _states = model.predict(obs, deterministic=True)\n", " obs, rewards, dones, info = vec_env.step(action)\n", " vec_env.render()" ], "metadata": { "id": "rpqfDLuFJnnO" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv\n", "\n", "def record_video(env_id, model, video_length=500, prefix='', video_folder='videos/'):\n", " \"\"\"\n", " :param env_id: (str)\n", " :param model: (RL model)\n", " :param video_length: (int)\n", " :param prefix: (str)\n", " :param video_folder: (str)\n", " \"\"\"\n", " eval_env = model.get_env()\n", " # Start the video at step=0 and record 500 steps\n", " eval_env = VecVideoRecorder(eval_env, video_folder=video_folder,\n", " record_video_trigger=lambda step: step == 0, video_length=video_length,\n", " name_prefix=prefix)\n", "\n", " obs = eval_env.reset()\n", " for _ in range(video_length):\n", " action, _ = model.predict(obs)\n", " obs, _, _, _ = eval_env.step(action)\n", "\n", " # Close the video recorder\n", " eval_env.close()\n", "\n", "import base64\n", "from pathlib import Path\n", "\n", "from IPython import display as ipythondisplay\n", "\n", "def show_videos(video_path='', prefix=''):\n", " \"\"\"\n", " Taken from https://github.com/eleurent/highway-env\n", "\n", " :param video_path: (str) Path to the folder containing videos\n", " :param prefix: (str) Filter the video, showing only the only starting with this prefix\n", " \"\"\"\n", " html = []\n", " for mp4 in Path(video_path).glob(\"{}*.mp4\".format(prefix)):\n", " video_b64 = base64.b64encode(mp4.read_bytes())\n", " html.append(''''''.format(mp4, video_b64.decode('ascii')))\n", " ipythondisplay.display(ipythondisplay.HTML(data=\"