File size: 3,520 Bytes
2d9a728
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize dataset\n",
    "\n",
    "Utilities to visualize episodes from a dataset. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "from matplotlib import pyplot as plt\n",
    "from matplotlib import animation\n",
    "import pathlib\n",
    "from IPython.display import Video\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "dataset_path = pathlib.Path(os.path.abspath('')).parent / 'data/stickman_example'\n",
    "\n",
    "directory = dataset_path.expanduser()\n",
    "filenames = sorted(directory.glob('*.npz'))\n",
    "if len(filenames) == 0:\n",
    "    raise ValueError(\"Empty directory (or no episodes)\")\n",
    "\n",
    "try:\n",
    "    filenames_dict = { int(str(f).replace(str(dataset_path), \"\").split(\"-\")[0][1:]) : f for f in filenames}\n",
    "except Exception as e:\n",
    "    print(\"Error:\", e)\n",
    "\n",
    "print(directory)\n",
    "print(len(filenames))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ep_num = next(iter(filenames_dict))\n",
    "\n",
    "filename = filenames_dict[ep_num]\n",
    "with filename.open('rb') as f:\n",
    "    episode = np.load(f)\n",
    "    episode = {k: episode[k] for k in episode.keys()}\n",
    "\n",
    "# Show reward on top with red/green bar\n",
    "pix_rew_max = np.round(episode['reward'] / 2 * 64)\n",
    "for ob, pix_n in zip(episode['observation'], pix_rew_max):\n",
    "    if pix_n < 0:\n",
    "        pix_n = abs(pix_n)\n",
    "        ob[:, 0, :int(pix_n+1)] = np.array([255,0,0]).reshape(3,1)\n",
    "    else:\n",
    "        ob[:, 0, :int(pix_n+1)] = np.array([0,255,0]).reshape(3,1)\n",
    "\n",
    "# # np array with shape (frames, height, width, channels)\n",
    "video = np.transpose(episode['observation'], axes=[0,2,3,1])\n",
    "\n",
    "fig = plt.figure(frameon=False)\n",
    "ax = plt.Axes(fig, [0., 0., 1., 1.])\n",
    "ax.set_axis_off()\n",
    "fig.add_axes(ax)\n",
    "fig.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0)\n",
    "fig.set_size_inches(2,2)\n",
    "im = ax.imshow(video[0,:,:,:])\n",
    "plt.close() # this is required to not display the generated image\n",
    "\n",
    "def init():\n",
    "    im.set_data(video[0,:,:,:])\n",
    "\n",
    "def animate(i):\n",
    "    im.set_data(video[i,:,:,:])\n",
    "    return im\n",
    "\n",
    "print('Episode reward', np.sum(episode['reward']))\n",
    "anim = animation.FuncAnimation(fig, animate, init_func=init, frames=video.shape[0],interval=45)\n",
    "file_path = str(pathlib.Path(os.path.abspath('')) / 'videos/temp.mp4')\n",
    "anim.save(file_path)\n",
    "print('Video file', file_path)\n",
    "Video(file_path)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "3d597f4c481aa0f25dceb95d2a0067e73c0966dcbd003d741d821a7208527ecf"
  },
  "kernelspec": {
   "display_name": "Python 3.8.10 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}