{ "cells": [ { "cell_type": "code", "execution_count": 1, "source": [ "import abc\n", "import random\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import tensorflow as tf\n", "from einops import repeat, rearrange\n", "\n", "# Hide all GPUs from TensorFlow (presumably so it leaves GPU memory to JAX).\n", "tf.config.experimental.set_visible_devices([], 'GPU')\n", "\n", "# Disable JAX's GPU memory preallocation (preallocation might lead to memory issues).\n", "# Delete the next two lines to re-enable it. Must run before JAX initializes its backend.\n", "import os\n", "os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "# Gecko\n", "\n", "Load a pretrained NCA that grows the 🦎 emoji (or train one from scratch) and render its growth as a video." ], "metadata": {} }, { "cell_type": "code", "execution_count": 2, "source": [ "from jax_nca.dataset import ImageDataset" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 3, "source": [ "dataset = ImageDataset(emoji='🦎', img_size=64)" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 4, "source": [ "dataset.visualize()" ], "outputs": [ { "output_type": "display_data", "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" } } ], "metadata": {} }, { "cell_type": "code", "execution_count": 5, "source": [ "from jax_nca.nca import NCA" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "### NCA\n", "- num_hidden_channels = 16\n", "- num_target_channels = 3\n", "- trainable_perception = False\n", "- cell_fire_rate = 1.0 (100% chance for cells to be updated)\n", "- alpha_living_threshold = 0.1 (threshold for cells to be alive)" ], "metadata": {} }, { "cell_type": "code", "execution_count": 6, "source": [ "nca = NCA(16, 3, trainable_perception=False, cell_fire_rate=1.0, alpha_living_threshold=0.1)" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 7, "source": [ "from jax_nca.trainer import EmojiTrainer" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 8, "source": [ "trainer = EmojiTrainer(dataset, nca, n_damage=0)" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 9, "source": [ "# Set TRAIN = True to train from scratch (100k steps) instead of loading the pretrained parameters below.\n", "TRAIN = False\n", "if TRAIN:\n", "    trainer.train(100000, batch_size=8, seed=10, lr=2e-4, min_steps=64, max_steps=96)" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "#### Get current state from trainer" ], "metadata": {} }, { "cell_type": "code", "execution_count": 10, "source": [ "state = trainer.state" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 11, "source": [ "# Save the freshly trained parameters (only when TRAIN is set).\n", "if TRAIN:\n", "    nca.save(state.params, \"saved_params\")" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 12, "source": [ "# Use the freshly trained parameters, otherwise load the pretrained gecko checkpoint.\n", "params = state.params if TRAIN else nca.load(\"gecko_100_cell_fire_rate\")" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 13, "source": [ "import numpy as np\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "%matplotlib ipympl\n", "\n", "plt.style.use('ggplot')\n", "# Imports specifically so we can render outputs in Jupyter.\n", "from JSAnimation.IPython_display import display_animation\n", "from matplotlib import animation\n", "from 
IPython.display import display\n", "from celluloid import Camera\n", "from IPython.display import HTML\n", "import jax\n", "import jax.numpy as jnp\n", "\n", "def render_nca_steps(nca, params, shape=(64, 64), num_steps=2, output_path='gecko.mp4', interval=50, seed=0):\n", "    \"\"\"Grow an NCA from a single seed state and save the rollout as an animation.\n", "\n", "    Args:\n", "        nca: NCA model providing create_seed, multi_step and to_rgb.\n", "        params: model parameters, e.g. the result of nca.load(...).\n", "        shape: (height, width) of the seed grid.\n", "        num_steps: number of NCA steps to roll out.\n", "        output_path: file the animation is saved to.\n", "        interval: delay between animation frames in milliseconds.\n", "        seed: integer seed for the PRNG key passed to nca.multi_step.\n", "\n", "    Returns:\n", "        (anim, outputs, rgbs): the animation returned by camera.animate, the\n", "        per-step outputs of nca.multi_step, and the RGB frames as a numpy array.\n", "    \"\"\"\n", "    nca_seed = nca.create_seed(nca.num_hidden_channels, nca.num_target_channels, shape=shape, batch_size=1)\n", "    rng = jax.random.PRNGKey(seed)\n", "    _, outputs = nca.multi_step(params, nca_seed, rng, num_steps=num_steps)\n", "    # Presumably removes the size-1 batch axis. NOTE(review): jnp.squeeze drops every\n", "    # size-1 axis, so num_steps == 1 would also lose the step axis -- confirm shapes.\n", "    stacked = jnp.squeeze(jnp.stack(outputs))\n", "    rgbs = np.array(nca.to_rgb(stacked))\n", "\n", "    # clear=True reuses the named figure but wipes it, so re-running this cell does\n", "    # not add another subplot on top of the previous run's one.\n", "    fig = plt.figure(\"Animation\", figsize=(7, 5), clear=True)\n", "    camera = Camera(fig)\n", "    ax = fig.add_subplot(111)\n", "    ax.axis('off')\n", "    for rgb in rgbs:\n", "        ax.imshow(rgb)\n", "        camera.snap()\n", "    # Named anim so it does not shadow the imported matplotlib animation module.\n", "    anim = camera.animate(blit=False, interval=interval)\n", "    anim.save(output_path)\n", "    return anim, outputs, rgbs" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 14, "source": [ "nca_animation, outputs, rgbs = render_nca_steps(nca, params, num_steps=256)" ], "outputs": [ { "output_type": "display_data", "data": { "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, "model_id": "c0e2eeba79a046e1a9e1b56275e1c911" }, "text/html": [ "\n", "
\n", "
\n", " Animation\n", "
\n", " \n", "
\n", " " ], "image/png": "", "text/plain": [ "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …" ] }, "metadata": {} } ], "metadata": {} }, { "cell_type": "code", "execution_count": 15, "source": [ "from IPython.display import Video\n", "\n", "Video(\"gecko.mp4\")" ], "outputs": [ { "output_type": "execute_result", "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "execution_count": 15 } ], "metadata": {} } ], "metadata": { "orig_nbformat": 4, "language_info": { "name": "python", "version": "3.9.0", "mimetype": "text/x-python", "codemirror_mode": { "name": "ipython", "version": 3 }, "pygments_lexer": "ipython3", "nbconvert_exporter": "python", "file_extension": ".py" }, "kernelspec": { "name": "python3", "display_name": "Python 3.9.0 64-bit ('jax_gpu': conda)" }, "interpreter": { "hash": "a7271dcc4a91420ffb9cc5ce7ff5a5d83d948f729c0ba20dec48f9a748a86390" } }, "nbformat": 4, "nbformat_minor": 2 }