{ "cells": [ { "cell_type": "code", "execution_count": 1, "source": [ "import abc\n", "import random\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import tensorflow as tf\n", "from einops import repeat, rearrange\n", "tf.config.experimental.set_visible_devices([], 'GPU')\n", "\n", "# uncomment this to enable jax gpu preallocation, might lead to memory issues\n", "\n", "import os\n", "os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "# Gecko" ], "metadata": {} }, { "cell_type": "code", "execution_count": 2, "source": [ "from jax_nca.dataset import ImageDataset" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 3, "source": [ "dataset = ImageDataset(emoji='🦎', img_size=64)" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 4, "source": [ "dataset.visualize()" ], "outputs": [ { "output_type": "display_data", "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD7CAYAAACscuKmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAcMElEQVR4nO3de3RV1Z0H8O8vT8IzQQIEgjwERQYFnCugRUdQLFVb1DrWR0dWS8u00zq22qq0s2a0ra22tdVZa8aWVa1MfeCjtVp0oYhanwWDKAKBgrwDgfAIEB4Jyf3NH/dw9t7H3OSS3HsT2N/PWix+9+5z79nJze+evc/ZZ29RVRDRyS+noytARNnBZCfyBJOdyBNMdiJPMNmJPMFkJ/JEu5JdRKaJyBoRWScid6arUkSUftLW6+wikgvg7wCmAtgK4H0A16vqqvRVj4jSJa8drx0PYJ2qrgcAEZkHYDqApMnep08fHTJkSDt2SUQt2bhxI3bt2iXNlbUn2QcC2GI93gpgQksvGDJkCCoqKtqxSyJqSSwWS1qW8RN0IjJLRCpEpKKmpibTuyOiJNqT7FUABlmPy4PnHKo6R1VjqhorLS1tx+6IqD3ak+zvAxghIkNFpADAdQBeSE+1iCjd2txnV9VGEfk2gJcB5AJ4RFVXpq1mRJRW7TlBB1V9CcBLaaoLEWUQR9AReYLJTuQJJjuRJ5jsRJ5gshN5gslO5AkmO5EnmOxEnmCyE3mCyU7kCSY7kSeY7ESeYLITeYLJTuQJJjuRJ5jsRJ5gshN5gslO5AkmO5EnmOxEnmCyE3mCyU7kCSY7kSeY7ESeYLITeaLVZBeRR0Rkp4issJ7rLSILRWRt8H9JZqtJRO2VypH9UQDTIs/dCWCRqo4AsCh4TESdWKvJrqpvAtgTeXo6gLlBPBfAlemtFhGlW1v77P1UdXsQVwPol6b6EFGGtPsEnaoqAE1WLiKzRKRCRCpqamrauzsiaqO2JvsOESkDgO
D/nck2VNU5qhpT1VhpaWkbd0dE7dXWZH8BwIwgngHg+fRUh4gyJZVLb08CeA/AGSKyVURmArgXwFQRWQvgkuAxEXViea1toKrXJym6OM11IaIMajXZKXMS5zYNETnu16X6GiIOlyXyBJOdyBNsxmeZInkTvLbhQBjfv+65MB7Spa+z3cxhZvRyPB53yuz3ZBOfbDyyE3mCyU7kCSY7kSfYZ88wjdw2ELcum8W1ySmb/s5dYfy2rjMFje57LK1dG8b/e87NTlmTmj58LthnJ4NHdiJPMNmJPMFmfIbFI6PkcsV8v95T+YRT9vZR0zwv2G+a43Fx3+O3+koYD17lXpa7Y9SXwthp0gu/133HvwAiTzDZiTzBZnwGxJG8+Vy5b1MY/2zNU05ZTtycnW+w3gPuIDnk1plm/Q82/sEpO6tkWBhfVnZuGNtN+ubqZbNH5UmOOaMf7ZLk2KP1eOa/0+ORncgTTHYiTzDZiTzBPnsGOF3bSFf2/rV/CuOGXu53rd0XR4PVx468R9x6mextcMq+sezBMF7e5zdh3Cuvm1tH++67yA5ycpo/BuTyLroTGo/sRJ5gshN5gs34NIje7GJf1qprPOSULaipCOOc1XvdN2qy3meYaXZrQa77/m/uCOOuc9Y4ZTtvGB7Gj4x+NYxvPfVKZ7uj8UZTj8hluE0Hq83753UJ41X7Nznbjeg+MIzLuyZfE4CX5ToHHtmJPMFkJ/IEk53IE+yzp4HGI/O/W0NMNx91V7uuXr81jItuq3DKUGT65ocfMENddUh3Z7Oc3fVmXwcOu/veaR7P37EkjL87aLqzXX5O8o8+9qqZEKOkwOx7Y3y3s92t1nv+YszXnbJG65xAXgv7ouxJZfmnQSLyuoisEpGVInJL8HxvEVkoImuD/0syX10iaqtUmvGNAG5T1VEAJgL4loiMAnAngEWqOgLAouAxEXVSqaz1th3A9iA+ICKVAAYCmA7gomCzuQDeAHBHRmrZyWkLV5YOHKlzn+hZEIbx4T3csi6mGa9drY+m7qizWcNnB5j3GNrTKYuPMO+5YbfpMkSXmtrVsC+MH9v8mlN2dvHQMF5Xty2MRxUMdLb7/aaXw7hXgTtC7z/OvCGMOYlG53Bcv3kRGQJgHIDFAPoFXwQAUA2gX3qrRkTplHKyi0h3AH8E8B1V3W+XaeKwoUleN0tEKkSkoqampl2VJaK2SynZRSQfiUR/XFWP3cmxQ0TKgvIyADube62qzlHVmKrGSkuTj7Iiosxqtc8uiQXDHgZQqaq/sopeADADwL3B/89npIYnAIm2aaw+fN+8Xk5Rbs/CMD5yz1j3ddb7aL75Hs6PzBufn2fK6se6F0Hsy4CHG8wluuidbO/trgzj7+12J74cfKhrGO8+Yhpx1Udrne3O7DkojMf2GuaUuWvagTqBVC6AfgbAvwD4WEQ+DJ77ARJJ/rSIzASwCcC1GakhEaVFKmfj38an7qgOXZze6hBRpnBoUxpEl0a2L3OVd3fPUwwv7B/Gq3O2O2WoN5eoiqw74Eq6uXe9dc81+6s/7E4kudmK8/PMx9tgjWgDgItLx4bx443fdsp+uv9J87oiU4+BuX2c7V6bdG8Ylxa63Ql70spkk2G0xn4P+/JmtNtkl/HSXnL8zRB5gslO5Ak249Mg2oxvbDJN5vxc91f81WHTwviO9XPd97Ga5DmF5nt4/CF3tdch1u7WlBQ6ZQ0NZtsSMaPpCiI3o9jN3etOvcgpi/U6LYxHPHpJGA8eOsDZzm66f2pe+jY03aOTgCRt/vPsfpvwyE7kCSY7kSeY7ESeYJ89A3JzzaUy+/IRAHxz2BVh/OiGV5yyyqKqMD5idUyv2HjA2e6CI+acwENTy5wysfq9q4/UhvGGg+5lvqHd3NfZPqj6yLzfQTMZRuUnS53t9jSY0XW9C9y771qal97W0iW6dXXm97G8dkMYnxap+6G4GSk48ZQzk+7L94kveWQn8gSTncgTbMZngN1cjF5Osu
dhf3LCbKfswre/H8b71TSfHzrdbSJX9jIf26hC9/u6MC8/jPfkHwnjS968zdnum8OuDuMNG5c7ZXOWmEuC2sX8LAVwR/LlWKP8opNjpNpitpvuR5vcUX73VJqRfP+3640w7llf4Gxn/753feGZpO+fatfiZMUjO5EnmOxEnmCyE3mCffYMi15OaoibJZZHWxM7AsCi881dZJ9fcncYV6m7JtxWq398eo7b95y2zkxwWWpNZDGvr9unvmuD6ZcfPOwu+5xTbi5txfLNENk/XvaAs11xkZmYozHS386zhgk3NJr3j64r997uVWF8zcf3OWVXFY8P4/EFZghvY747fLj2qPmZJ7/5fafsxUk/DuNueUVhHD2X4kMfnkd2Ik8w2Yk8wWZ8lhXkFCQtG97DNJnfsCaGuOSt7znbVTUeDOMdR9wm7ZXLasO4cKiZy339ae5HXVhv5qJfIflO2b4SM0lFVRfT9J1T/bqz3e3dvhjGPQvdJarsu+AK8pL/zAO7m8k8+nU9xSn7aJ8ZNVd1eFcY16vbZbDvsNvd4I42jNyM5zUe2Yk8wWQn8gSb8RkWndThR6seC+OjkXnhHtlollPadrkZPXZF/0nOdn+pNTfQbDjinlV+6YtmiaaiAvNdHotMR92zqykrzHXLVh0w9dpeb1ahvWfrs249qv8Wxo+OvdUpG9d7RBjPX7sojF/b/K6zXf8upul+a9cLnLL795nfR1VTbRh3EffP9qHRZg69mwZPBTWPR3YiTzDZiTzBZCfyBPvsaRAdjWWLztf+cnVFGO+qd9bHRK88c6msrtHc9XbT4CnOdg9vXRDGy/Pc9++fby5zndndfLynr9rnbHf5e+ZS1iNXlztlxVZfv36HmRhidYO7r+XdtoTxxe/8wCl7aPS/hvGsZ02fen9RPVzWyLVGd2lqOcWcf5DhJj6vabiznd1Pb2ElLu+1emQXkS4iskREPhKRlSJyd/D8UBFZLCLrROQpEUl+MZWIOlwqzfh6AFNUdQyAsQCmichEAPcB+LWqDgewF8DMjNWSiNotlbXeFMCxOw3yg38KYAqAG4Ln5wK4C8BD6a9i52evnAq4N79c9c5dTlnPQtNUP2A11QFgnZr169/dY1ZZndY/5mx3Vr5ZPXUttjhlfazLZt27mMkmSvu488uvnWAueQ3s55YVWstQffkvZu6690rckXa3jjfvsavuoFN2w/u/COPcQlOP3Dr3UmS/7v3CuLx3f6esoJ8ZUdjTWjbr9qFXOdvFYS0TFWnHczkoI9X12XODFVx3AlgI4BMAtarhuMWtAAYmeTkRdQIpJbuqNqnqWADlAMYDGJnqDkRklohUiEhFTU1N6y8goow4rjaOqtYCeB3AeQCKRcKhTOUAqpK8Zo6qxlQ1Vlpa2twmRJQFrfbZRaQUwFFVrRWRIgBTkTg59zqAawDMAzADwPOZrOiJ6nNl5zqP5255NYxX17n9bRSbvvOyfevDOC/SH16x5s0wrh/oXjZblm++v/N3mQkn0bfI2W5Eb9P/HvXmLqds05jiMF45ydwBFy9wjw2f72n+fJZFDhsf1ZsnastM31urNjjbPXy5mbBi2ojJaDdea0sqlevsZQDmikguEi2Bp1V1voisAjBPRH4CYBmAhzNYTyJqp1TOxi8HMK6Z59cj0X8nohMAR9ClgbbQdLywz1nO41srfxfGk/uc7ZStOrg1jH9e+UQYN65e42x3uNFM0CBb3UtZOwcPCeP3rSWgD+9155nru9OMZBv5gTvH3baRZqnnquFmUoq+mw45251jXdqLbXHLlteZS4C/79c1jHfVu8s+jzzFzC0XnXu+KW4m5hDrcqZELq8lXdqZHPwtEXmCyU7kCTbj0yCnhVPA0RVHn/lHs+TT1QPdSSmerno7jK97644w1gb3hhl0M7chPDDhu05R/0FnmPdYauaxO1Tn3sRypJs5G7/iKnc8VFmxKRuz0VwJOPelbc52L88cFsaj9rk3sQytNd2G5U
NNV+CD/sXOdq/UrgjjWb1PdcrsprszEo5n3NuER3YiTzDZiTzBZCfyBPvsaSCSvBPZs6Cb89jup8fj7mWza62yPxeb+OUSdwTdjyeayR3/7dybku67e44Zkfe15f/tlP0t11y++6TRrcfIbeZuvO3WBBjV17p96i555lixItbbKWuyfiVnW29fF3cvAf680kzA+aVB/+SU9cw3l+x8X245HXhkJ/IEk53IE2zGZ1h0VFjcao7m5CRvjv7h0p+G8e4L3JVJ+3azbk6JdAXs0XyXDZgQxkuLH3S2++1mMyf7w+sXOGXvHjYj6jYfNaPYNvfIdbabWG/Kyru6f0o9rUtvVy+vDeP603o621X3MCPvntn0V6fsa8MvC+NGay6/vBz+2bYFj+xEnmCyE3mCyU7kCXZ+Mix6WS43xctGuTmmf2z30QH3bjB7uyi7n1vW1X2P6d1Hh/G8rY84ZVsbd4fxJ73NpJIH1J0A44jVZ5+U5x43xlqTVg6sNpNonDbC7bOX55vfxwvb3nDK7D57Do9L7cbfIJEnmOxEnmAzvpNyLtlFWv4tNd2TLUVV1+COwrvoiRvCeH+OWyYHTfO/qNaUHRrgznf3bqkZHZi3113WqUeJGb23//rBYdwUGa1Xtsdcolt6YK1Ttu2QmRtvgNUNif6MHFGXGh7ZiTzBZCfyBJvxnVRLN9e0+DqrSWuPNOte4H7Uv5hiJsf491fvdsrq88xZ9nNGm5tTfnf+bGe712vNElV3rf2NUzbgoHmPyVare9gOd8mrtdbIu7wit4n/Qe06835WMz4eGZWY28bflW94ZCfyBJOdyBNMdiJPsM9+AmppIoea+towvuNjs0hPrPh0Z7vivmbZ5x4DhjllDYf3hPG7cTPJ5NdXuf3yNy80yzJX7neXsnrryCthPK7WTEZ59sp97nbnmr54YaH7s6w5sCmMr8DEMFZ1+/bgsswpSfm3FCzbvExE5gePh4rIYhFZJyJPiUhBa+9BRB3neL4SbwFQaT2+D8CvVXU4gL0AZqazYkSUXik140WkHMDlAO4BcKskrgtNAXBsGNZcAHcBeCgDdaSIlprx260m+NzNC8P4rzuXO9ttPLTDPLBGuwFATo/+Zl9dzPHgnZqPne02HTTvce8Y97t+9CtvhPGSXmaUXO3kfs52DfXmZ8lrcpvnGw5uR7OaHyRIrUj1yP4AgNsBHPs0TgFQq6rHxlVuBTCwmdcRUSfRarKLyBUAdqrq0rbsQERmiUiFiFTU1NS05S2IKA1SObJ/BsAXRGQjgHlINN8fBFAsIse6AeUAqpp7sarOUdWYqsZKS0vTUGUiaotU1mefDWA2AIjIRQC+p6o3isgzAK5B4gtgBoDnM1dNstkTOUQnnDy72FxG+6+RXw7j9/a7yz7bff1t9budsnpr0ovvll8ZxjMHXepsN6jIfHlHl02e3n9yGC85ZCa3zDvo1teeRb4JruojtWiORPZln8OIDqW1l3f2fWnn9vz0dyBxsm4dEn34h1vZnog60HENqlHVNwC8EcTrAYxPf5WIKBM4gu4EF23S2l7ctjiMNxx1T45+c7CZ3+3pLe587fayS78c/fWk79/YZJr7OepeAvzKkM+G8bPvzg/jvHx3u3yrO3Eozy3bD/cOuWMkcunNvkPwU3fAWQ/tLo+PTXr/fmIiTzHZiTzBZvwJL/nyUk+f9x9hvHSPO7/bP3/0szCe2nOMU7bqwOYwXnvAXFE9rVuZs12OdQNK9Cz4WSVDzeu6mptw1h5161HUYF63NzKCrmvczH8XR/Im+J6G/WH84f4NTtnAQrO67Bk9BsFnPLITeYLJTuQJJjuRJ9hnP8F9es5083hwN3OHWb64c81/dcDUMD6/5Eyn7MXqJWFcXGDmho/2le2Ra/aSVIA7yu/GwWZf31jj9tnzjpjXHc11+/1ab+56O9xo5qXff/SQs13stZvDuLrQnQM/77B5z9lDrwnju0bf5Gznw2W5k/OnIqJPYbITeU
JUszcTQCwW04qKiqztz3fZbJq2tCRTbUNdGI9c4E5yUdNkLpvlxt0uSZM1J93Cc+8J4wm9z3C2W2rNL790z9+dstsrzQq1cWv03nvn/dLZbnzvkWY7uJcAT6QVZGOxGCoqKpqdSP/E+SmIqF2Y7ESeYLITeYKX3k5idj892qe2h7dGbxSzT+PkWH3vltafi14CbLLmdi8u6B7G3zn9ame7H256LIzz6pwiNFm3t92/9tkwfnHST5ztLiw9q9kYAO5d81QY7847EsZL97qXAJ0+e2TYbk7uyXFMPDl+CiJqFZOdyBNsxnsi2sxucZnjNKyAbDf/7UuAt4yY7mz3xKbXwnhlkTtnaYE1UG5B07IwvnHxvc52Vw48P4z/su1vTtnuxgNhrLlmFGHfwuKkdW9pQpAT2cn5UxHRpzDZiTzBZjxlhH3m3h6lWZTbxdnumfPNBBuXvDPbKdvW3az4mlNnbpiZt+MtZzvncTwyIrTQarrHzVWBKX3HOZvZdcxpqYtzAuORncgTTHYiTzDZiTzBPjtlnD2SL3pHmT0J5Gvnu5fUbv7YrAC+sOFDUxCZe955ywL3+FUmvcL4sXHfD+MSa1QfELlDUE7OY2Cq67NvBHAAieW4GlU1JiK9ATwFYAiAjQCuVdW9makmEbXX8XyFTVbVsaoaCx7fCWCRqo4AsCh4TESdVHua8dMBXBTEc5FYA+6OdtaHTnLRiSDs5vOInuVO2YLPmAkr3qr5OIz/asUAUHPUXKL7h15DnbKr+k0I49Kikmb3C5y8887ZUv0JFcArIrJURGYFz/VT1WMzAlYD6Nf8S4moM0j1yD5JVatEpC+AhSKy2i5UVRWJLreXEHw5zAKAU089tV2VJaK2S+nIrqpVwf87ATyHxFLNO0SkDACC/3cmee0cVY2paqy0tDQ9tSai49bqkV1EugHIUdUDQXwpgB8BeAHADAD3Bv8/n8mK0snJuSwX6UdLjrnEdoE1KcUFkQkqUtXSenE+SKUZ3w/Ac8FY5zwAT6jqAhF5H8DTIjITwCYA12aumkTUXq0mu6quBzCmmed3A7g4E5UiovTjCDrqNDLdtD6R5n/PBL9/eiKPMNmJPMFkJ/IEk53IE0x2Ik8w2Yk8wWQn8gSTncgTTHYiTzDZiTzBZCfyBJOdyBNMdiJPMNmJPMFkJ/IEk53IE0x2Ik8w2Yk8wWQn8gSTncgTTHYiTzDZiTzBZCfyBJOdyBMpJbuIFIvIsyKyWkQqReQ8EektIgtFZG3wf0nr70REHSXVI/uDABao6kgkloKqBHAngEWqOgLAouAxEXVSrSa7iPQCcCGAhwFAVRtUtRbAdABzg83mArgyM1UkonRI5cg+FEANgN+LyDIR+V2wdHM/Vd0ebFONxGqvRNRJpZLseQDOAfCQqo4DcBCRJruqKgBt7sUiMktEKkSkoqampr31JaI2SiXZtwLYqqqLg8fPIpH8O0SkDACC/3c292JVnaOqMVWNlZaWpqPORNQGrSa7qlYD2CIiZwRPXQxgFYAXAMwInpsB4PmM1JCI0iLV9dlvBvC4iBQAWA/gK0h8UTwtIjMBbAJwbWaqSETpkFKyq+qHAGLNFF2c1toQUcZwBB2RJ5jsRJ5gshN5gslO5AkmO5EnmOxEnmCyE3lCEsPas7QzkRokBuD0AbAraztuXmeoA8B6RLEeruOtx2BVbXZcelaTPdypSIWqNjdIx6s6sB6sRzbrwWY8kSeY7ESe6Khkn9NB+7V1hjoArEcU6+FKWz06pM9ORNnHZjyRJ7Ka7CIyTUTWiMg6EcnabLQi8oiI7BSRFdZzWZ8KW0QGicjrIrJKRFaKyC0dURcR6SIiS0Tko6AedwfPDxWRxcHn81Qwf0HGiUhuML/h/I6qh4hsFJGPReRDEakInuuIv5GMTduetWQXkVwA/wPgcwBGAbheREZlafePApgWea4jpsJuBHCbqo4CMBHAt4LfQbbrUg9giqqOATAWwDQRmQjgPgC/Vt
XhAPYCmJnhehxzCxLTkx/TUfWYrKpjrUtdHfE3krlp21U1K/8AnAfgZevxbACzs7j/IQBWWI/XACgL4jIAa7JVF6sOzwOY2pF1AdAVwAcAJiAxeCOvuc8rg/svD/6ApwCYD0A6qB4bAfSJPJfVzwVALwAbEJxLS3c9stmMHwhgi/V4a/BcR+nQqbBFZAiAcQAWd0Rdgqbzh0hMFLoQwCcAalW1MdgkW5/PAwBuBxAPHp/SQfVQAK+IyFIRmRU8l+3PJaPTtvMEHVqeCjsTRKQ7gD8C+I6q7u+Iuqhqk6qOReLIOh7AyEzvM0pErgCwU1WXZnvfzZikqucg0c38lohcaBdm6XNp17TtrclmslcBGGQ9Lg+e6ygpTYWdbiKSj0SiP66qf+rIugCAJlb3eR2J5nKxiByblzAbn89nAHxBRDYCmIdEU/7BDqgHVLUq+H8ngOeQ+ALM9ufSrmnbW5PNZH8fwIjgTGsBgOuQmI66o2R9KmwRESSW0apU1V91VF1EpFREioO4CInzBpVIJP012aqHqs5W1XJVHYLE38NrqnpjtushIt1EpMexGMClAFYgy5+LZnra9kyf+IicaLgMwN+R6B/+MIv7fRLAdgBHkfj2nIlE33ARgLUAXgXQOwv1mIREE2w5gA+Df5dluy4AzgawLKjHCgD/GTw/DMASAOsAPAOgMIuf0UUA5ndEPYL9fRT8W3nsb7OD/kbGAqgIPps/AyhJVz04go7IEzxBR+QJJjuRJ5jsRJ5gshN5gslO5AkmO5EnmOxEnmCyE3ni/wETE8t04/X9pgAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" } } ], "metadata": {} }, { "cell_type": "code", "execution_count": 5, "source": [ "from jax_nca.nca import NCA" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "### NCA\n", "- num_hidden_channels = 16\n", "- num_target_channels = 3\n", "- cell_fire_rate = 1.0 (100% chance for cells to be updated)\n", "- alpha_living_threshold = 0.1 (threshold for cells to be alive)" ], "metadata": {} }, { "cell_type": "code", "execution_count": 6, "source": [ "nca = NCA(16, 3, trainable_perception=False, cell_fire_rate=1.0, alpha_living_threshold=0.1)" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 7, "source": [ "from jax_nca.trainer import EmojiTrainer" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 8, "source": [ "trainer = EmojiTrainer(dataset, nca, n_damage=0)" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 9, "source": [ "# trainer.train(100000, batch_size=8, seed=10, lr=2e-4, min_steps=64, max_steps=96)" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "#### Get current state from trainer" ], "metadata": {} }, { "cell_type": "code", "execution_count": 10, "source": [ "state = trainer.state" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 11, "source": [ "# save\n", "# nca.save(state.params, \"saved_params\")" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 12, "source": [ "params = nca.load(\"gecko_100_cell_fire_rate\")" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 13, "source": [ "import numpy as np\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "%matplotlib ipympl\n", "\n", "plt.style.use('ggplot')\n", "# Imports specifically so we can render outputs in Jupyter.\n", "from JSAnimation.IPython_display import display_animation\n", "from matplotlib import animation\n", "from 
IPython.display import display\n", "from celluloid import Camera\n", "from IPython.display import HTML\n", "import jax\n", "import jax.numpy as jnp\n", "\n",
    "def render_nca_steps(nca, params, shape=(64, 64), num_steps=2, seed=0, interval=50, save_path='gecko.mp4'):\n",
    "    \"\"\"Roll out the NCA from its seed state and animate every intermediate step.\n",
    "\n",
    "    Args:\n",
    "        nca: model object; must provide create_seed, multi_step, to_rgb and the\n",
    "            num_hidden_channels / num_target_channels attributes.\n",
    "        params: parameters forwarded to nca.multi_step.\n",
    "        shape: shape forwarded to nca.create_seed.\n",
    "        num_steps: number of steps forwarded to nca.multi_step.\n",
    "        seed: integer used to build the jax PRNG key for the rollout.\n",
    "        interval: delay between animation frames, forwarded to camera.animate\n",
    "            (milliseconds, per matplotlib's animation API).\n",
    "        save_path: file the animation is written to; pass None to skip saving.\n",
    "\n",
    "    Returns:\n",
    "        (anim, outputs, rgbs): the animation object, the per-step outputs returned\n",
    "        by nca.multi_step, and the numpy array of frames that were drawn.\n",
    "    \"\"\"\n",
    "    nca_seed = nca.create_seed(nca.num_hidden_channels, nca.num_target_channels, shape=shape, batch_size=1)\n",
    "    rng = jax.random.PRNGKey(seed)\n",
    "    _, outputs = nca.multi_step(params, nca_seed, rng, num_steps=num_steps)\n",
    "    # NOTE(review): squeeze() drops every size-1 axis, not only the batch axis;\n",
    "    # presumably only the batch axis is meant - confirm before using num_steps=1.\n",
    "    stacked = jnp.squeeze(jnp.stack(outputs))\n",
    "    rgbs = np.array(nca.to_rgb(stacked))\n",
    "\n",
    "    # The figure is looked up by name, so without clear=True a re-run of this cell\n",
    "    # would draw on top of the previous run's axes and frames.\n",
    "    fig = plt.figure(\"Animation\", figsize=(7, 5), clear=True)\n",
    "    camera = Camera(fig)\n",
    "    ax = fig.add_subplot(111)\n",
    "    ax.axis('off')\n",
    "    for rgb in rgbs:\n",
    "        ax.imshow(rgb)\n",
    "        # snap() records the artists added since the previous snap as one frame\n",
    "        camera.snap()\n",
    "    # named anim (not animation) so the imported matplotlib.animation module is not shadowed here\n",
    "    anim = camera.animate(blit=False, interval=interval)\n",
    "    if save_path is not None:\n",
    "        anim.save(save_path)\n",
    "    return anim, outputs, rgbs" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 14, "source": [ "animation, outputs, rgbs = render_nca_steps(nca, params, num_steps=256)" ], "outputs": [ { "output_type": "display_data", "data": { "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "c0e2eeba79a046e1a9e1b56275e1c911" }, "text/html": [ "\n", "
\n", "
\n", " Animation\n", "
\n", " \n", "
\n", " " ], "image/png": "iVBORw0KGgoAAAANSUhEUgAAArwAAAH0CAYAAADfWf7fAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZiElEQVR4nO3da5CdB33f8efs2dtZ7a7usmTJsnzD2JDY1Da1zd1JS4ZJgdIEDDEeksk4aWBat9M2tA2UAiE0wJRpMi6lbQopbV2mIVxKU8CkwfgSjI2NjS2IL9gW1sWyZGm1q72ePX3RN5nJ78lEjF3Zf30+L7+Wzh6tNKufnpn9uzMYDAYNAAAUNXSy3wAAADybDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKG34ZL8BgD+vvzqIfXGwFPtQy7/bx4ZGYu90frz3BcDzlye8AACUZvACAFCawQsAQGkGLwAApRm8AACU5koDcFKstPSlzmrs777v9/JPGOQf/8EX/1LsQ5387/zJ7ljLOwLg+c4TXgAASjN4AQAozeAFAKA0gxcAgNIMXgAASnOlAXhWtV1jaPvX9q/c+dHYP7P/T2IfHR6P/dbjfxb71y77QOyrLVcajvTnYt/QXRM7AM89nvACAFCawQsAQGkGLwAApRm8AACUZvACAFBaZzAYDE72mwDqWm3pn997a+xv/85HYh8aH4199vhs/gDD+d/zb97+stjff94vxH5+b2d+/U7OADz3eMILAEBpBi8AAKUZvAAAlGbwAgBQmsELAEBprjQAz4iFJn8pWe0vx77tK2/LrzPSctdhqR9zfyT/u31qJff5lrsRn7zonbG/ZfsrY2+7PtFr8jUJAE4eT3gBACjN4AUAoDSDFwCA0gxeAABKM3gBACht+GS/AaCG8aYT+w17bop9rr8Y++hy/nf4Qn+p5ePmqwgLs/OxD031Yn/Xdz8R+1WnXRz7juFNsZ+o1ZbrFu3y5/lEf/SJvQrA85snvAAAlGbwAgBQmsELAEBpBi8AAKUZvAAAlOZKA3BC2m4KtPXfefCLsXefXol9ea7lesO68dhXWq49NE/MxDz2hfvy6/zk5tg/tu
tzsf/mhdfGPtHk97na9GM/unI89pnF3A8uHYt94/ja2Hf18q8L4FTiCS8AAKUZvAAAlGbwAgBQmsELAEBpBi8AAKW50gCckLZrDI/O7Yv90OLh2Dddf0fss6PLsS/9k4vz+zlzIvbukXzlYHkmXzno7l6N/SsP3xb7Ry68Lvamn19nKB9paC7941/LLzPWjX1hkK9bXLv11bF/6MJfzu8nv81myGMQoCBf2gAAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM2VBuCEdFquBBzr5H5w4enYh8fzj1+/bWPsM+vyl6uhybHcLz0t9qWWqw7NYj6jcLRZyD++xWx3KfbvHn4w9jPGN8T+xEr+vI22fNn+T4/flF9n7qnY/8Ml/yj2nr8WgII84QUAoDSDFwCA0gxeAABKM3gBACjN4AUAoDTfjguckEHLl401w73Yx3vTsS9/+lWxH39qkD/wwvGYp1YWYx9bl68xHO7kl+8fz9cV5non9lzgwaUDsf/yff869r1PPxn7Uid/Hpa7+eNuGpmK/Q2bXxr7kMcdwCnElzwAAEozeAEAKM3gBQCgNIMXAIDSDF4AAEpzpQE4IYOW6wGnjeRrDOeNboz94dHZ2Ec3zsc+cij/+/y01Zibbj//h3Xr8/WGx1bz+Ya18/njzi7PxX7R6I7Yb7z412P/4Lf/fexfPHxn7KdNbYn9ziv/TezrpzfE3l1oOVcxnjPA85knvAAAlGbwAgBQmsELAEBpBi8AAKUZvAAAlOZKA3BCuk3+7v5eyzf9v3HbFbF/YP/n8+sP9WJftzlfXXjJzELsE938hh5r+rF3pvN5gtkjy7FPrXTz64zkKxYXbDg79s+88n2x7/jPr4t9op/fz+a1m2If7rQ813iWrzHMDfL7HGp5P6NN/ny2PZVp+ePWDFbzn5POkOc7cCrzFQAAgNIMXgAASjN4AQAozeAFAKA0gxcAgNJcaQCeEcP9/O/nd537+tj/YO/NsT+wvC/2pZbvsr+kpV9xcCV/3DPyNYD13dzvH12K/YHlJ2K/sHdW7J0mXy24Zc/dsc89fTT2wWr+9R6YOxT79snNsT/bhlrOKDw8uyf2R+f3xz45nK92LK7k35dXrntx7L2W6x/AqcETXgAASjN4AQAozeAFAKA0gxcAgNIMXgAASusMBoP8P34HeAYMmvwl5pGj+crBxTf/auz9yfz6Fyytxn71lqnYp2dmYt/fcu3h7mY89kdX18R+/Xlvjf3+Pbtj/8yt/yP2g8fy1YXpiXWxP3T912LfMDId+4ma7y/GPtTNx34WmnxF4bo7Pxb7l578Vv7Ag/z7u66ZiP3el98Q+8bpTbF3Oi3nJIBSPOEFAKA0gxcAgNIMXgAASjN4AQAozeAFAKA0VxqAZ9VKS58Z5O/iPzC7N/Yrbn5n7CNNvh5wxbpe7H9nNL+fC/fNx37PmZtj/29z+Vf2wIG52OdH8rWBxaePxX7OsXyW4n++Pl8hOGPrzthXhvMVgvEmX1doM+jnvyq+fuCO2P/2t9+X+6aXxb5n9XDsc4v592W25QrEttV8neMLr/pg7JNjLec/gFI84QUAoDSDFwCA0gxeAABKM3gBACjN4AUAoLQT+zZdgBZt517ybYKmmejk6wFnTm2P/ZbXfCL219x0Xez7+vmKwsGRNbGPHM/XAF53x2OxP3rJ6bEvr81XAh6Zn419Zku+JvHIjunYP71wc+zXPP3q2LdtPjN/3JYrB5Or+a+Ffjf/Du+Y3BL7lt6G2O/q58/n4YV8rWIwyH9OxkbzuY29zUzsQ0Oe78CpzFcAAABKM3gBACjN4AUAoDSDFwCA0gxeAABK6wwGg7Zvrgb4K+u39Pff/6nYx5v83ffvv/+/xP7km78c+wd2/37sHz/4udivXTMW+yXd/CtYu2Uy9qPH5mKfnR2J/fa55djvWcx9ppuvEBwfyZ+3yXycoPnKlf8q9rN7W2O/+cDdsf/xnjtiP218few71uTX//juP4j9e8
P7Y1/Kn4bmN85+S+z/4Iw3xD7VWxv7cMufQ6AWT3gBACjN4AUAoDSDFwCA0gxeAABKM3gBACgt/0/TAVqs9PNVgZVOvnLw2Yf/d+zHVuZjP723Lvbh1YXYrz7rqtg/vTd/3HuO5dfprc3frX/hYv51vfR7h2M///F8LmHyyrNy7+QvwwtPHo/94al8ZeLekaXY3/Sdfxb7h8/6u7H//T/8h7EfGMrvZ2J4OvZObyr25S357MLSaP78XzZ9QezX78jXGCbG81UN1xjg1OYJLwAApRm8AACUZvACAFCawQsAQGkGLwAApbnSAJyQoU439jfd8huxbxxbG/uh5dnY9/WPxP6NfffG/tPbr4x9Wz9fD3h0ZSX2DcfzNYbtE4PYn1o7EfvwRfnjnnVuy49fGYn9Z277Qey7N+QrBO95Yb5+sKflGsZbH/hw7J1mMfZmIV/nmBjkz9v568+I/chwfs7S6+Vf1zVbXpE/7lgv9rGWqxfAqc0TXgAASjN4AQAozeAFAKA0gxcAgNIMXgAASvPtrMCJyUcLmtfvfFXsv/vg52Kfy0cFmsHxfCXgS0/cHvtQpxP7Az/6fuy9Tevyj5/I3/U/fSRfdeidPhX77Pqx2F9w31OxL1+6I/a9l26Mfe14/g147XS+cnDXbL7S8O2hfB1i+bStsc88cSj2G3/+htiv2PSi2Jux/HG7w7kPNfn3d6jreQ3wV+crBgAApRm8AACUZvACAFCawQsAQGkGLwAApXUGg0HL91wD/EWry/3YH+/n7+L/iZt+JfaXTp8T+91P7Y796bn8+pOzq7Ef7+brCpu27oz9qen869qwmvtPTOQzE780nI/fXP3FP4v9a792fuzzs/k6wZbF/JziaMvji6Xd+fP2neX8pf9TW9fEfmjfsdi/9ap8peGcjWfEPjGaXx/g2eQJLwAApRm8AACUZvACAFCawQsAQGkGLwAApRm8AACUlu/nALQZ6ca8uTMV+yde8vdi/7nTXhb7v33oD2N/93f/XX4/q7M5H5+L/UOv+YXYN27K58re/KcfjP3Ohfxxhyd7se/+2XyGbXIlfxm+qJ9f/6Iv7In9lp//ydhf0clnwDa2nJf707F8bu0HuzbHftfCQ7G/oHtm7AAngye8AACUZvACAFCawQsAQGkGLwAApRm8AACU1hkMBoOT/SaA57/FwUrsQ51O7EuDfCVgtJ//HX7NHflawm0H7439F7f9zdjfffE1+f108/WJe44+GPvPfOO9sfd25WsVEzMHY3/h6HTs5w4txn7l6ELsm8dOj70/nT/Pex85Gvuefv79ums0X2842l8X+22vuCH21ZY/D0NN/vwDPBM84QUAoDSDFwCA0gxeAABKM3gBACjN4AUAoDRXGoDnlJUmf0labekr/eXY265DjA3lawP9Qb6KsLqanwscHByL/cYHvhj7R777+7Ef3ZLfz+n56EVz4fa1sb+ytxT7mWO92LvH8uftnN3zsX9y03Ds3xzP7/+f77wu9rfs+KnYO03+/QJ4JnjCCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaa40APw4VvKVg73H98f+6hvfEfuDg4OxD593Ruyblruxv2IqP7/46XX5SsMZx/NVihc90Y/966P54/7XdfmvkInBubF//q//ZuydoZHYm5ZrGwAnwhNeAABKM3gBACjN4AUAoDSDFwCA0gxeAABKy/9zdAD+Usv9pdjf/NVfj/1HK8di787n1+kdOB771m35+sE3534U+0LLkYM3rs/POxbOH419eF++6rB2Il9XuPfwI7E/uvhk7LsmdsTuRgPwTPCEFwCA0gxeAABKM3gBACjN4AUAoDSDFwCA0lxpAGiaZqFZjX11OV9LmBlaiP2t57429rtn8xWF/syh2C/YcFbsn738fbHfPv947O958OOxr5+dj/1vjfZjn1g7EfvUsfz5aXr5ecp9R34Y+66J7fl1BjnPr+arEb3ueP4JwCnNE14AAEozeAEAKM
3gBQCgNIMXAIDSDF4AAEpzpQGgaZql/lLsb7/rw7H/3PaXxb5pakvs3bF8bmD9ztNjv395f+xvuPVfxn7bT90Q+94zron9U9//ROyXr52O/W/88Ejs/2fjaOzjw53Y7zv2aOyvb14e+5eevDP2l/byFYvetCsNwF/kCS8AAKUZvAAAlGbwAgBQmsELAEBpBi8AAKW50gDQNM3KoB/7zfvuif3x2Sdjf+ypH8be6Y3EvjzI1xsW8w9vHmq53jDf5CsT1257bewfvf8/xv7thYX8gU/Lf10MtTw3mVjIn88fze7Nr9/isqlzY984vOaEXgc4tXnCCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaa40ADRNMzo8Fvs/ffHbY/9fB+6Ifc3a9bE/sXwo9omJjbH/9tlvi/3qs14d+1jLl/ORfu5XT14e+5cHt+XXGYzGvnRsLvbVtfmKwtzqfOz9wWrsWyfy5xPgRHjCCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaa40ADRNM9J0Y//ko38U+97FfHXh+nPeEPt/f+KbsZ89vTP2q6ZfHHt3sRP75Hi+orDazdcPrr3o6thvvOkbsd+/Pj8f2bBxbewjy8uxz3fy6wwt9WNfHs2/L0urS7GPdvOPb/v9PbaUr0xMjeYrE8Dzkye8AACUZvACAFCawQsAQGkGLwAApRm8AACU5koDQNM0nWYQ+5+8/GOx33b0/tjfcedHY79o/dmx75s7EPvOqa2x98bGY2+z1PLruqC3K/Zdo9tif7h/JPa9M0/HvjIxld/Pkf2xz4/F3Ay3vP8DC/lKxg8W98X+wvHTY98+sSn2hdV8ZaJpmmZ8aKT1vwHPTZ7wAgBQmsELAEBpBi8AAKUZvAAAlGbwAgBQmisNAE3TDOdjAM3IWP4yeV4vf9f/G7ddGfvOsc2xH1o+lj9utxv7eHNiFwLGm/w6yy3XD9552bWx/+Jtvx37+pbHJjNzx2MfDD8V+/Hl+diP9g/Hfu6t1+UPPNLy19pS/vVev/11sX/oRe/Ir9M0TX+QX6vb6bT+HODk8oQXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSOoNBy7ebAtAsDJZj77R8R/5Q09bz84W2L8BtX5lHnqFLACstH3m56cd+4dffHvvMfL4ysbK0FPvi+Hjsn7/8vbFfPnl+7N9f3Bf7LYd2x/6PH/y92JuV/Ov96sX/Iv/4pmmu2npp7K40wHOXJ7wAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJTW8j8dB6Bpmma8M3JyPvCz/A3/w63XJPJfCx99ybtif9utvxX7+jVTsc8cz1cdbtjzR7Hv2Lkh9sunzo790u3nxP5bD3829sNDM7Hf8/TDsTdN01y19ZKW/+JKAzxXecILAEBpBi8AAKUZvAAAlGbwAgBQmsELAEBprjQA8Of0Y/3ZDZfF/qYzr4j9xv3fir3bz5cMvvTYrbF38ttpdnXWxn5g7kjswy0HFDr9Qexbevk6RNM0TWeQf44jDfDc5QkvAAClGbwAAJRm8AIAUJrBCwBAaQYvAACldQaDtm83BYD/59hgPvZOJ58muOKrvxr7QyuHY184nl9/bHIy9olON/a5ZiX2TsvfdC/obYv9yy95b/4JTdNsndwS+8hQfk/AyecJLwAApRm8AACUZvACAFCawQsAQGkGLwAApbnSAMCPbbHlKsKhxUOxv+n298f+vYX9sc/NHYl9TS9fb+jmoxHN5qG1sX/hkvfEfu76XfmFmqYZ7o7kj936M4CTzRNeAABKM3gBACjN4AUAoDSDFwCA0gxeAABKc6UBgB/bwmAh9k4n3ywYabllcPfRR2Lff2hf7Lcv5B9/3uTpsb9uw1+LfcOaDbH/ZX8zDg+1nIIAnrM84QUAoDSDFwCA0gxeAABKM3gBACjN4AUAoDRXGgD4/6bf9GNfbfmbaGZ1PvZ1q2P59bv5gsJwy3WIIRcX4JTgCS8AAKUZvAAAlGbwAg
BQmsELAEBpBi8AAKW50gAAQGme8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaf8XdahW29FwEN8AAAAASUVORK5CYII=", "text/plain": [ "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …" ] }, "metadata": {} } ], "metadata": {} }, { "cell_type": "code", "execution_count": 15, "source": [ "from IPython.display import Video\n", "\n", "Video(\"gecko.mp4\")" ], "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "execution_count": 15 } ], "metadata": {} } ], "metadata": { "orig_nbformat": 4, "language_info": { "name": "python", "version": "3.9.0", "mimetype": "text/x-python", "codemirror_mode": { "name": "ipython", "version": 3 }, "pygments_lexer": "ipython3", "nbconvert_exporter": "python", "file_extension": ".py" }, "kernelspec": { "name": "python3", "display_name": "Python 3.9.0 64-bit ('jax_gpu': conda)" }, "interpreter": { "hash": "a7271dcc4a91420ffb9cc5ce7ff5a5d83d948f729c0ba20dec48f9a748a86390" } }, "nbformat": 4, "nbformat_minor": 2 }