shyamsn97 commited on
Commit
434b57f
1 Parent(s): 4b71f88

first commit

Browse files
.flake8 ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ [flake8]
2
+ ignore = E203, E266, E501, W503, C901 E731
3
+ # line length is intentionally set to 80 here because black uses Bugbear
4
+ # See https://github.com/psf/black/blob/master/README.md#line-length for more details
5
+ max-line-length = 1000
6
+ max-complexity = 18
7
+ select = B,C,E,F,W,T4,B9
8
+ # We need to configure the mypy.ini because the flake8-mypy's default
9
+ # options don't properly override it, so if we don't specify it we get
10
+ # half of the config from mypy.ini and half from flake8-mypy.
11
+ mypy_config = mypy.ini
.gitignore ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.py[cod]
2
+
3
+ # C extensions
4
+ *.so
5
+
6
+ # Packages
7
+ *.egg
8
+ *.egg-info
9
+ dist
10
+ build
11
+ eggs
12
+ parts
13
+ bin
14
+ var
15
+ sdist
16
+ develop-eggs
17
+ .installed.cfg
18
+ lib
19
+ lib64
20
+ __pycache__
21
+
22
+ # Installer logs
23
+ pip-log.txt
24
+
25
+ # Unit test / coverage reports
26
+ .coverage
27
+ .tox
28
+ nosetests.xml
29
+
30
+ # Translations
31
+ *.mo
32
+
33
+ # Mr Developer
34
+ .mr.developer.cfg
35
+ .project
36
+ .pydevproject
37
+ test.json
38
+ *.pickle
39
+ venv
40
+ .idea
41
+ *.vscode/
42
+
43
+ #notebooks
44
+ */**/.ipynb_checkpoints/*
45
+
46
+ # logs
47
+ */**/checkpoints/*
48
+ */**/mlruns/*
49
+ */**/tensorboard_logs/*
50
+ */**/wandb/*
51
+ checkpoints/*
52
+ wandb/*
53
+ mlruns/*
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ The MIT License (MIT)
2
+
3
+ Copyright (c) 2021 Shyam Sudhakaran
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Makefile ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts
2
+
3
+ clean-build: ## remove build artifacts
4
+ rm -fr build/
5
+ rm -fr dist/
6
+ rm -fr .eggs/
7
+ find . -name '*.egg-info' -exec rm -fr {} +
8
+ find . -name '*.egg' -exec rm -f {} +
9
+
10
+ clean-pyc: ## remove Python file artifacts
11
+ find . -name '*.pyc' -exec rm -f {} +
12
+ find . -name '*.pyo' -exec rm -f {} +
13
+ find . -name '*~' -exec rm -f {} +
14
+ find . -name '__pycache__' -exec rm -fr {} +
15
+
16
+ clean-test: ## remove test and coverage artifacts
17
+ rm -fr .tox/
18
+ rm -f .coverage
19
+ rm -fr coverage/
20
+ rm -fr .pytest_cache
21
+
22
+ lint: ## check style with flake8
23
+ isort --profile black jax_nca
24
+ black jax_nca
25
+ flake8 jax_nca
26
+
27
+ install: clean lint
28
+ python setup.py install
README.md ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Neural Cellular Automata (Based on https://distill.pub/2020/growing-ca/) implemented in Jax (Flax)
2
+
3
+ ![Gecko gif](https://raw.githubusercontent.com/shyamsn97/jax-nca/main/images/gecko.gif?token=GHSAT0AAAAAABTB4G7FLAJSLDHSIOQONS3IYTB5ZEA)
4
+
5
+ ---
6
+
7
+
8
+ ## Installation
9
+ from source:
10
+ ```
11
+ git clone git@github.com:shyamsn97/jax-nca.git
12
+ cd jax-nca
13
+ python setup.py install
14
+ ```
15
+
16
+ from PYPI
17
+ ```
18
+ pip install jax-nca
19
+ ```
20
+ ---
21
+
22
+ ## How do NCAs work?
23
+ For more information, view the awesome article https://distill.pub/2020/growing-ca/ -- Mordvintsev, et al., "Growing Neural Cellular Automata", Distill, 2020
24
+
25
+ Image below describes a single update step: https://github.com/distillpub/post--growing-ca/blob/master/public/figures/model.svg
26
+
27
+ ![NCA update](https://raw.githubusercontent.com/shyamsn97/jax-nca/main/images/model.svg?token=GHSAT0AAAAAABTB4G7FOWOPXEUYVLBGRNSWYTB5YUA)
28
+
29
+ ---
30
+
31
+ ## Why Jax?
32
+
33
+ <b> Note: This project served as a nice introduction to jax, so its performance can probably be improved </b>
34
+
35
+ NCAs are autoregressive models like RNNs, where new states are calculated from previous ones. With jax, we can make these operations a lot more performant with `jax.lax.scan` and `jax.jit` (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html)
36
+
37
+ Instead of writing the nca growth process as:
38
+
39
+ ```python
40
+ def multi_step(params, nca, current_state, num_steps):
41
+ # params: parameters for NCA
42
+ # nca: Flax Module describing NCA
43
+ # current_state: Current NCA state
44
+ # num_steps: number of steps to run
45
+
46
+ for i in range(num_steps):
47
+ current_state = nca.apply(params, current_state)
48
+ return current_state
49
+ ```
50
+
51
+ We can write this with `jax.lax.scan`
52
+
53
+ ```python
54
+ def multi_step(params, nca, current_state, num_steps):
55
+ # params: parameters for NCA
56
+ # nca: Flax Module describing NCA
57
+ # current_state: Current NCA state
58
+ # num_steps: number of steps to run
59
+
60
+ def forward(carry, inp):
61
+ carry = nca.apply({"params": params}, carry)
62
+ return carry, carry
63
+
64
+ final_state, nca_states = jax.lax.scan(forward, current_state, None, length=num_steps)
65
+ return final_state
66
+ ```
67
+ The actual multi_step implementation can be found here: https://github.com/shyamsn97/jax-nca/blob/main/jax_nca/nca.py#L103
68
+
69
+ ---
70
+
71
+ ## Usage
72
+ See [notebooks/Gecko.ipynb](notebooks/Gecko.ipynb) for a full example
73
+
74
+ <b> Currently there's a bug with the stochastic update, so only `cell_fire_rate = 1.0` works at the moment </b>
75
+
76
+ Creating and using NCA:
77
+
78
+ ```python
79
+ class NCA(nn.Module):
80
+ num_hidden_channels: int
81
+ num_target_channels: int = 3
82
+ alpha_living_threshold: float = 0.1
83
+ cell_fire_rate: float = 1.0
84
+ trainable_perception: bool = False
85
+ alpha: float = 1.0
86
+
87
+ """
88
+ num_hidden_channels: Number of hidden channels for each cell to use
89
+ num_target_channels: Number of target channels to be used
90
+ alpha_living_threshold: threshold to determine whether a cell lives or dies
91
+ cell_fire_rate: probability that a cell receives an update per step
92
+ trainable_perception: if true, instead of using sobel filters use a trainable conv net
93
+ alpha: scalar value to be multiplied to updates
94
+ """
95
+ ...
96
+
97
+ from jax_nca.nca import NCA
98
+
99
+ # usage
100
+ nca = NCA(
101
+ num_hidden_channels = 16,
102
+ num_target_channels = 3,
103
+ trainable_perception = False,
104
+ cell_fire_rate = 1.0,
105
+ alpha_living_threshold = 0.1
106
+ )
107
+
108
+ nca_seed = nca.create_seed(
109
+ nca.num_hidden_channels, nca.num_target_channels, shape=(64,64), batch_size=1
110
+ )
111
+ rng = jax.random.PRNGKey(0)
112
+ params = = nca.init(rng, nca_seed, rng)["params"]
113
+ update = nca.apply({"params":params}, nca_seed, jax.random.PRNGKey(10))
114
+
115
+ # multi step
116
+
117
+ final_state, nca_states = nca.multi_step(poarams, nca_seed, jax.random.PRNGKey(10), num_steps=32)
118
+ ```
119
+
120
+ To train the NCA:
121
+ ```python
122
+ from jax_nca.dataset import ImageDataset
123
+ from jax_nca.trainer import EmojiTrainer
124
+
125
+
126
+ dataset = ImageDataset(emoji='🦎', img_size=64)
127
+
128
+
129
+ nca = NCA(
130
+ num_hidden_channels = 16,
131
+ num_target_channels = 3,
132
+ trainable_perception = False,
133
+ cell_fire_rate = 1.0,
134
+ alpha_living_threshold = 0.1
135
+ )
136
+
137
+ trainer = EmojiTrainer(dataset, nca, n_damage=0)
138
+
139
+ trainer.train(100000, batch_size=8, seed=10, lr=2e-4, min_steps=64, max_steps=96)
140
+
141
+ # to access train state:
142
+
143
+ state = trainer.state
144
+
145
+ # save
146
+ nca.save(state.params, "saved_params")
147
+
148
+ # load params
149
+ loaded_params = nca.load("saved_params")
150
+
151
+ ```
images/gecko.gif ADDED
images/model.svg ADDED
jax_nca/__init__.py ADDED
File without changes
jax_nca/dataset.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ from einops import repeat
4
+
5
+ from jax_nca.utils import load_emoji
6
+
7
+
8
+ def to_alpha(x):
9
+ return np.clip(x[:, :, :, 3:4], 0.0, 1.0)
10
+
11
+
12
+ def rgb(x, rgb=False):
13
+ # assume rgb premultiplied by alpha
14
+ if rgb:
15
+ return np.clip(x[:, :, :, :3], 0.0, 1.0)
16
+ rgb, a = x[:, :, :, :3], to_alpha(x)
17
+ return np.clip(1.0 - a + rgb, 0.0, 1.0)
18
+
19
+
20
+ class ImageDataset:
21
+ def __init__(self, emoji: str = None, img: np.array = None, img_size: int = 64):
22
+ if img is None:
23
+ img = load_emoji(emoji, img_size)
24
+ self.rgb = img.shape[-1] == 3
25
+ self.img_shape = img.shape
26
+ self.img = np.expand_dims(img, 0) # (b w h c)
27
+ self.rgb_img = rgb(self.img, self.rgb)
28
+
29
+ def get_batch(self, batch_size: int = 1):
30
+ return repeat(
31
+ self.img, "b w h c -> (b repeat) w h c", repeat=batch_size
32
+ ), repeat(self.rgb_img, "b w h c -> (b repeat) w h c", repeat=batch_size)
33
+
34
+ def visualize(self):
35
+ _ = plt.imshow(self.rgb_img[0])
36
+ plt.show()
jax_nca/nca.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ from typing import Tuple
3
+
4
+ import flax.linen as nn
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import numpy as np
8
+ from flax import serialization
9
+ from jax import lax
10
+
11
+
12
+ class SobelPerceptionNet(nn.Module):
13
+ @nn.compact
14
+ def __call__(self, x):
15
+ # x shape - BHWC
16
+
17
+ num_channels = x.shape[-1]
18
+
19
+ # 2D sobel kernels - IOHW layout
20
+
21
+ x_sobel_kernel = jnp.zeros(
22
+ (num_channels, num_channels, 3, 3), dtype=jnp.float32
23
+ )
24
+ x_sobel_kernel += (
25
+ jnp.array([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]])[
26
+ jnp.newaxis, jnp.newaxis, :, :
27
+ ]
28
+ / 8.0
29
+ )
30
+
31
+ y_sobel_kernel = jnp.zeros(
32
+ (num_channels, num_channels, 3, 3), dtype=jnp.float32
33
+ )
34
+ y_sobel_kernel += (
35
+ jnp.array([[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]])[
36
+ jnp.newaxis, jnp.newaxis, :, :
37
+ ]
38
+ / 8.0
39
+ )
40
+ x = jnp.transpose(x, [0, 3, 1, 2]) # N C H W
41
+
42
+ x_out = lax.conv(
43
+ x, # lhs = NCHW image tensor
44
+ x_sobel_kernel, # rhs = OIHW conv kernel tensor
45
+ (1, 1), # window strides
46
+ "SAME",
47
+ ) # padding mode
48
+
49
+ y_out = lax.conv(
50
+ x, # lhs = NCHW image tensor
51
+ y_sobel_kernel, # rhs = OIHW conv kernel tensor
52
+ (1, 1), # window strides
53
+ "SAME",
54
+ ) # padding mode
55
+
56
+ out = jnp.concatenate([x, x_out, y_out], axis=1)
57
+ return jnp.transpose(out, [0, 2, 3, 1]) # N H W C
58
+
59
+
60
+ class UpdateNet(nn.Module):
61
+
62
+ num_channels: int
63
+
64
+ @nn.compact
65
+ def __call__(self, x):
66
+ update_layer_1 = nn.Conv(
67
+ features=64, kernel_size=(1, 1), strides=1, padding="VALID"
68
+ )
69
+ update_layer_2 = nn.Conv(
70
+ features=64, kernel_size=(1, 1), strides=1, padding="VALID"
71
+ )
72
+ update_layer_3 = nn.Conv(
73
+ features=self.num_channels,
74
+ kernel_size=(1, 1),
75
+ strides=1,
76
+ padding="VALID",
77
+ kernel_init=jax.nn.initializers.zeros,
78
+ use_bias=False,
79
+ )
80
+ x = update_layer_1(x)
81
+ x = nn.relu(x)
82
+ x = update_layer_2(x)
83
+ x = nn.relu(x)
84
+ x = update_layer_3(x)
85
+ return x
86
+
87
+
88
+ class TrainablePerception(nn.Module):
89
+ num_channels: int
90
+
91
+ @nn.compact
92
+ def __call__(self, x):
93
+ out = nn.Conv(
94
+ features=self.num_channels * 3,
95
+ kernel_size=(3, 3),
96
+ use_bias=False,
97
+ feature_group_count=self.num_channels,
98
+ )(x)
99
+ return out
100
+
101
+
102
+ @functools.partial(jax.jit, static_argnames=("apply_fn", "num_steps"))
103
+ def nca_multi_step(apply_fn, params, current_state: jnp.array, rng, num_steps: int):
104
+ def forward(carry, inp):
105
+ carry = apply_fn({"params": params}, carry, rng)
106
+ return carry, carry
107
+
108
+ x, outs = jax.lax.scan(forward, current_state, None, length=num_steps)
109
+ return x, outs
110
+
111
+
112
+ class NCA(nn.Module):
113
+ num_hidden_channels: int
114
+ num_target_channels: int = 3
115
+ alpha_living_threshold: float = 0.1
116
+ cell_fire_rate: float = 1.0
117
+ trainable_perception: bool = False
118
+ alpha: float = 1.0
119
+
120
+ """
121
+ num_hidden_channels: Number of hidden channels for each cell to use
122
+ num_target_channels: Number of target channels to be used
123
+ alpha_living_threshold: threshold to determine whether a cell lives or dies
124
+ cell_fire_rate: probability that a cell receives an update per step
125
+ trainable_perception: if true, instead of using sobel filters use a trainable conv net
126
+ alpha: scalar value to be multiplied to updates
127
+ """
128
+
129
+ @classmethod
130
+ def create_seed(
131
+ cls,
132
+ num_hidden_channels: int,
133
+ num_target_channels: int = 3,
134
+ shape: Tuple[int] = (48, 48),
135
+ batch_size: int = 1,
136
+ ):
137
+ seed = np.zeros((batch_size, *shape, num_hidden_channels + 3 + 1))
138
+ w, h = seed.shape[1], seed.shape[2]
139
+ seed[:, w // 2, h // 2, 3:] = 1.0
140
+ return seed
141
+
142
+ def setup(self):
143
+ num_channels = 3 + self.num_hidden_channels + 1
144
+ if self.trainable_perception:
145
+ self.perception = TrainablePerception(num_channels)
146
+ else:
147
+ self.perception = SobelPerceptionNet()
148
+ self.update_net = UpdateNet(num_channels)
149
+
150
+ def alive(self, x, alpha_living_threshold: float):
151
+ return (
152
+ nn.max_pool(
153
+ x[..., 3:4], window_shape=(3, 3), strides=(1, 1), padding="SAME"
154
+ )
155
+ > alpha_living_threshold
156
+ )
157
+
158
+ def get_stochastic_update_mask(self, x, rng, cell_fire_rate: float = 1.0):
159
+ return jnp.array(np.random.uniform(size=x[..., :1].shape) <= cell_fire_rate)
160
+
161
+ def __call__(self, x, rng):
162
+ pre_life_mask = self.alive(x, self.alpha_living_threshold)
163
+
164
+ perception_out = self.perception(x)
165
+ update = self.alpha * jnp.reshape(self.update_net(perception_out), x.shape)
166
+
167
+ if self.cell_fire_rate >= 1.0:
168
+ stochastic_update_mask = self.get_stochastic_update_mask(
169
+ x, rng, self.cell_fire_rate
170
+ ).astype(float)
171
+ x = x + update * stochastic_update_mask
172
+ else:
173
+ x = x + update
174
+
175
+ post_life_mask = self.alive(x, self.alpha_living_threshold)
176
+
177
+ life_mask = pre_life_mask & post_life_mask
178
+ life_mask = life_mask.astype(float)
179
+
180
+ return x * life_mask
181
+
182
+ def save(self, params, path: str):
183
+ bytes_output = serialization.to_bytes(params)
184
+ with open(path, "wb") as f:
185
+ f.write(bytes_output)
186
+
187
+ def load(self, path: str):
188
+ nca_seed = self.create_seed(
189
+ self.num_hidden_channels, self.num_target_channels, batch_size=1
190
+ )
191
+ rng = jax.random.PRNGKey(0)
192
+ init_params = self.init(rng, nca_seed, rng)["params"]
193
+ with open(path, "rb") as f:
194
+ bytes_output = f.read()
195
+ return serialization.from_bytes(init_params, bytes_output)
196
+
197
+ def multi_step(self, params, current_state: jnp.array, rng, num_steps: int = 2):
198
+ return nca_multi_step(self.apply, params, current_state, rng, num_steps)
199
+
200
+ def to_rgb(self, x: jnp.array):
201
+ rgb, a = x[..., :3], jnp.clip(x[..., 3:4], 0.0, 1.0)
202
+ rgb = jnp.clip(1.0 - a + rgb, 0.0, 1.0)
203
+ return rgb
jax_nca/trainer.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import os
3
+ from collections.abc import Iterable
4
+ from datetime import datetime
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ import optax
10
+ import pandas as pd
11
+ import tqdm
12
+ from flax.training import train_state # Useful dataclass to keep train state
13
+ from tensorboardX import SummaryWriter
14
+
15
+ from jax_nca.utils import make_circle_masks
16
+
17
+
18
+ def get_tensorboard_logger(
19
+ experiment_name: str, base_log_path: str = "tensorboard_logs"
20
+ ):
21
+ log_path = "{}/{}_{}".format(base_log_path, experiment_name, datetime.now())
22
+ train_writer = SummaryWriter(log_path, flush_secs=10)
23
+ full_log_path = os.path.join(os.getcwd(), log_path)
24
+ print(
25
+ "Follow tensorboard logs with: python -m tensorboard.main --logdir '{}'".format(
26
+ full_log_path
27
+ )
28
+ )
29
+ return train_writer
30
+
31
+
32
+ def create_train_state(rng, nca, learning_rate, shape):
33
+ nca_seed = nca.create_seed(
34
+ nca.num_hidden_channels, nca.num_target_channels, shape=shape[:-1], batch_size=1
35
+ )
36
+ """Creates initial `TrainState`."""
37
+ params = nca.init(rng, nca_seed, rng)["params"]
38
+ tx = optax.chain(
39
+ # optax.clip_by_global_norm(10.0),
40
+ optax.adam(learning_rate),
41
+ )
42
+ return train_state.TrainState.create(apply_fn=nca.apply, params=params, tx=tx)
43
+
44
+
45
+ def clip_grad_norm(grad):
46
+ factor = 1.0 / (
47
+ jnp.linalg.norm(jax.tree_util.tree_leaves(jax.tree_map(jnp.linalg.norm, grad)))
48
+ + 1e-8
49
+ )
50
+ return jax.tree_map((lambda x: x * factor), grad)
51
+
52
+
53
+ @functools.partial(jax.jit, static_argnames=("apply_fn", "num_steps"))
54
+ def train_step(
55
+ apply_fn, state, seeds: jnp.array, targets: jnp.array, num_steps: int, rng
56
+ ):
57
+ def mse_loss(pred, y):
58
+ squared_diff = jnp.square(pred - y)
59
+ return jnp.mean(squared_diff, axis=[-3, -2, -1])
60
+
61
+ def loss_fn(params):
62
+ def forward(carry, inp):
63
+ carry = apply_fn({"params": params}, carry, rng)
64
+ return carry, carry
65
+
66
+ x, outs = jax.lax.scan(forward, seeds, None, length=num_steps)
67
+ rgb, a = x[..., :3], jnp.clip(x[..., 3:4], 0.0, 1.0)
68
+ rgb = jnp.clip(1.0 - a + rgb, 0.0, 1.0)
69
+
70
+ outs = jnp.transpose(outs, [1, 0, 2, 3, 4])
71
+ subset = outs[:, -8:] # B 12 H W C
72
+ return jnp.mean(
73
+ jax.vmap(mse_loss)(subset[..., :4], jnp.expand_dims(targets, 1))
74
+ ), (x, rgb)
75
+
76
+ grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
77
+ (loss, aux), grads = grad_fn(state.params)
78
+ grads = clip_grad_norm(grads)
79
+ updated, rgb = aux
80
+ return state.apply_gradients(grads=grads), loss, grads, updated, rgb
81
+
82
+
83
+ class SamplePool:
84
+ def __init__(self, max_size: int = 1000):
85
+ self.max_size = max_size
86
+ self.pool = [None] * max_size
87
+
88
+ def __getitem__(self, idx):
89
+ if isinstance(idx, Iterable):
90
+ return [self.pool[i] for i in idx]
91
+ return idx
92
+
93
+ def __setitem__(self, idx, v):
94
+ if isinstance(idx, Iterable):
95
+ for i in range(len(idx)):
96
+ index = idx[i]
97
+ self.pool[index] = v[i]
98
+ else:
99
+ self.pool[idx] = v
100
+
101
+ def sample(self, num_samples: int):
102
+ indices = np.random.randint(0, self.max_size, num_samples)
103
+ return self.__getitem__(indices), indices
104
+
105
+
106
+ def flatten(d):
107
+ df = pd.json_normalize(d, sep="_")
108
+ return df.to_dict(orient="records")[0]
109
+
110
+
111
+ class EmojiTrainer:
112
+ def __init__(self, dataset, nca, pool_size: int = 1024, n_damage: int = 0):
113
+ self.dataset = dataset
114
+ self.img_shape = self.dataset.img_shape
115
+ self.nca = nca
116
+ self.pool_size = pool_size
117
+ self.n_damage = n_damage
118
+ self.state = None
119
+
120
+ def train(
121
+ self,
122
+ num_epochs,
123
+ batch_size: int = 8,
124
+ seed: int = 10,
125
+ lr: float = 0.001,
126
+ min_steps: int = 64,
127
+ max_steps: int = 96,
128
+ ):
129
+ pool = SamplePool(self.pool_size)
130
+
131
+ writer = get_tensorboard_logger("EMOJITrainer")
132
+ rng = jax.random.PRNGKey(seed)
133
+ rng, init_rng = jax.random.split(rng)
134
+ self.state = create_train_state(init_rng, self.nca, lr, self.dataset.img_shape)
135
+
136
+ bar = tqdm.tqdm(np.arange(num_epochs))
137
+ try:
138
+ for i in bar:
139
+ num_steps = int(np.random.randint(min_steps, max_steps))
140
+ samples, indices = pool.sample(batch_size)
141
+ for j in range(len(samples)):
142
+ if samples[j] is None:
143
+ samples[j] = self.nca.create_seed(
144
+ self.nca.num_hidden_channels,
145
+ self.nca.num_target_channels,
146
+ shape=self.img_shape[:-1],
147
+ batch_size=1,
148
+ )[0]
149
+ samples[0] = self.nca.create_seed(
150
+ self.nca.num_hidden_channels,
151
+ self.nca.num_target_channels,
152
+ shape=self.img_shape[:-1],
153
+ batch_size=1,
154
+ )[0]
155
+ batch = np.stack(samples)
156
+ if self.n_damage > 0:
157
+ damage = (
158
+ 1.0
159
+ - make_circle_masks(
160
+ int(self.n_damage), self.img_shape[0], self.img_shape[1]
161
+ )[..., None]
162
+ )
163
+ batch[-self.n_damage :] *= damage
164
+
165
+ batch = jnp.array(batch)
166
+ targets, rgb_targets = self.dataset.get_batch(batch_size)
167
+ targets = jnp.array(targets)
168
+
169
+ self.state, loss, grads, outputs, rgb_outputs = train_step(
170
+ self.nca.apply,
171
+ self.state,
172
+ batch,
173
+ targets,
174
+ num_steps=num_steps,
175
+ rng=rng,
176
+ )
177
+
178
+ grad_dict = {k: dict(grads[k]) for k in grads.keys()}
179
+ grad_dict = flatten(grad_dict)
180
+
181
+ grad_dict = {
182
+ k: {kk: np.sum(vv).item() for kk, vv in v.items()}
183
+ for k, v in grad_dict.items()
184
+ }
185
+ grad_dict = flatten(grad_dict)
186
+
187
+ pool[indices] = np.array(outputs)
188
+
189
+ bar.set_description("Loss: {}".format(loss.item()))
190
+
191
+ self.emit_metrics(
192
+ writer,
193
+ i,
194
+ batch,
195
+ rgb_outputs,
196
+ rgb_targets,
197
+ loss.item(),
198
+ metrics=grad_dict,
199
+ )
200
+
201
+ return self.state
202
+ except Exception:
203
+ return self.state
204
+
205
+ def emit_metrics(
206
+ self, train_writer, i: int, batch, outputs, targets, loss, metrics={}
207
+ ):
208
+ train_writer.add_scalar("loss", loss, i)
209
+ # train_writer.add_scalar("log10(loss)", math.log10(loss), i)
210
+ train_writer.add_images("batch", self.nca.to_rgb(batch), i, dataformats="NHWC")
211
+ train_writer.add_images("outputs", outputs, i, dataformats="NHWC")
212
+ train_writer.add_images("targets", targets, i, dataformats="NHWC")
213
+ for k in metrics:
214
+ train_writer.add_scalar(k, metrics[k], i)
jax_nca/utils.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+
3
+ import numpy as np
4
+ import PIL.Image
5
+ import requests
6
+
7
+
8
+ def load_image(url, size):
9
+ r = requests.get(url)
10
+ img = PIL.Image.open(io.BytesIO(r.content))
11
+ img.thumbnail((40, 40), PIL.Image.ANTIALIAS)
12
+ img = np.float32(img) / 255.0
13
+ # premultiply RGB by Alpha
14
+ img[..., :3] *= img[..., 3:]
15
+ # pad to self.h, self.h
16
+ diff = size - 40
17
+ img = np.pad(img, ((diff // 2, diff // 2), (diff // 2, diff // 2), (0, 0)))
18
+ return img
19
+
20
+
21
+ def load_emoji(emoji, size, code=None):
22
+ if code is None:
23
+ code = hex(ord(emoji))[2:].lower()
24
+ url = (
25
+ "https://github.com/googlefonts/noto-emoji/blob/main/png/128/emoji_u%s.png?raw=true"
26
+ % code
27
+ )
28
+ return load_image(url, size)
29
+
30
+
31
+ def make_circle_masks(n, h, w, r=None):
32
+ x = np.linspace(-1.0, 1.0, w)[None, None, :]
33
+ y = np.linspace(-1.0, 1.0, h)[None, :, None]
34
+ center = np.random.uniform(-0.5, 0.5, size=[2, n, 1, 1])
35
+ if r is None:
36
+ r = np.random.uniform(0.1, 0.4, size=[n, 1, 1])
37
+ x, y = (x - center[0]) / r, (y - center[1]) / r
38
+ mask = x * x + y * y < 1.0
39
+ return mask.astype(float)
notebooks/Gecko.ipynb ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "source": [
7
+ "import abc\n",
8
+ "import random\n",
9
+ "import numpy as np\n",
10
+ "import matplotlib.pyplot as plt\n",
11
+ "import tensorflow as tf\n",
12
+ "from einops import repeat, rearrange\n",
13
+ "tf.config.experimental.set_visible_devices([], 'GPU')\n",
14
+ "\n",
15
+ "# uncomment this to enable jax gpu preallocation, might lead to memory issues\n",
16
+ "\n",
17
+ "import os\n",
18
+ "os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\""
19
+ ],
20
+ "outputs": [],
21
+ "metadata": {}
22
+ },
23
+ {
24
+ "cell_type": "markdown",
25
+ "source": [
26
+ "# Gecko"
27
+ ],
28
+ "metadata": {}
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": 2,
33
+ "source": [
34
+ "from jax_nca.dataset import ImageDataset"
35
+ ],
36
+ "outputs": [],
37
+ "metadata": {}
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": 3,
42
+ "source": [
43
+ "dataset = ImageDataset(emoji='🦎', img_size=64)"
44
+ ],
45
+ "outputs": [],
46
+ "metadata": {}
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": 4,
51
+ "source": [
52
+ "dataset.visualize()"
53
+ ],
54
+ "outputs": [
55
+ {
56
+ "output_type": "display_data",
57
+ "data": {
58
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD7CAYAAACscuKmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAcMElEQVR4nO3de3RV1Z0H8O8vT8IzQQIEgjwERQYFnCugRUdQLFVb1DrWR0dWS8u00zq22qq0s2a0ra22tdVZa8aWVa1MfeCjtVp0oYhanwWDKAKBgrwDgfAIEB4Jyf3NH/dw9t7H3OSS3HsT2N/PWix+9+5z79nJze+evc/ZZ29RVRDRyS+noytARNnBZCfyBJOdyBNMdiJPMNmJPMFkJ/JEu5JdRKaJyBoRWScid6arUkSUftLW6+wikgvg7wCmAtgK4H0A16vqqvRVj4jSJa8drx0PYJ2qrgcAEZkHYDqApMnep08fHTJkSDt2SUQt2bhxI3bt2iXNlbUn2QcC2GI93gpgQksvGDJkCCoqKtqxSyJqSSwWS1qW8RN0IjJLRCpEpKKmpibTuyOiJNqT7FUABlmPy4PnHKo6R1VjqhorLS1tx+6IqD3ak+zvAxghIkNFpADAdQBeSE+1iCjd2txnV9VGEfk2gJcB5AJ4RFVXpq1mRJRW7TlBB1V9CcBLaaoLEWUQR9AReYLJTuQJJjuRJ5jsRJ5gshN5gslO5AkmO5EnmOxEnmCyE3mCyU7kCSY7kSeY7ESeYLITeYLJTuQJJjuRJ5jsRJ5gshN5gslO5AkmO5EnmOxEnmCyE3mCyU7kCSY7kSeY7ESeYLITeaLVZBeRR0Rkp4issJ7rLSILRWRt8H9JZqtJRO2VypH9UQDTIs/dCWCRqo4AsCh4TESdWKvJrqpvAtgTeXo6gLlBPBfAlemtFhGlW1v77P1UdXsQVwPol6b6EFGGtPsEnaoqAE1WLiKzRKRCRCpqamrauzsiaqO2JvsOESkDgOD/nck2VNU5qhpT1VhpaWkbd0dE7dXWZH8BwIwgngHg+fRUh4gyJZVLb08CeA/AGSKyVURmArgXwFQRWQvgkuAxEXViea1toKrXJym6OM11IaIMajXZKXMS5zYNETnu16X6GiIOlyXyBJOdyBNsxmeZInkTvLbhQBjfv+65MB7Spa+z3cxhZvRyPB53yuz3ZBOfbDyyE3mCyU7kCSY7kSfYZ88wjdw2ELcum8W1ySmb/s5dYfy2rjMFje57LK1dG8b/e87NTlmTmj58LthnJ4NHdiJPMNmJPMFmfIbFI6PkcsV8v95T+YRT9vZR0zwv2G+a43Fx3+O3+koYD17lXpa7Y9SXwthp0gu/133HvwAiTzDZiTzBZnwGxJG8+Vy5b1MY/2zNU05ZTtycnW+w3gPuIDnk1plm/Q82/sEpO6tkWBhfVnZuGNtN+ubqZbNH5UmOOaMf7ZLk2KP1eOa/0+ORncgTTHYiTzDZiTzBPnsGOF3bSFf2/rV/CuOGXu53rd0XR4PVx468R9x6mextcMq+sezBMF7e5zdh3Cuvm1tH++67yA5ycpo/BuTyLroTGo/sRJ5gshN5gs34NIje7GJf1qprPOSULaipCOOc1XvdN2qy3meYaXZrQa77/m/uCOOuc9Y4ZTtvGB7Gj4x+NYxvPfVKZ7uj8UZTj8hluE0Hq83753UJ41X7Nznbjeg+MIzLuyZfE4CX5ToHHtmJPMFkJ/IEk53IE+yzp4HGI/O/W0NMNx91V7uuXr81jItuq3DKUGT65ocfMENddUh3Z7Oc3fVmXwcOu/veaR7P37EkjL87aLqzXX5O8o8+9qqZEKOkwOx7Y3y3s92t1nv+YszXnbJG65xAXgv7ouxJZfmnQSLyuoisEpGVInJL8HxvEVkoImuD/0syX10iaqtUmvGNAG5T1VEAJgL4loiMAnAngEWqOgLAouAxEXVSqaz1th3A9iA+ICKVAAYCmA7gomCzuQDeAHBHRmrZyWkLV5YOHKlzn+hZEIbx4T3csi6mGa9drY+m7qizWcNnB5j3GNrTKYuPMO+5YbfpMkSXmtrVsC+MH9v8mlN2dvHQMF5Xty2MRxUMdLb7/aaXw7hXgTtC7z/OvCGMOYlG53Bcv3kRGQJgHIDFAPoFXwQAUA2gX3qrRkTplHKyi0h3AH8E8B1V3W+XaeKwoUleN0tEKkSkoqampl2VJaK2SynZRSQfiUR/XFWP3cmxQ0TKgvIyADube62qzlHVmKrGSkuTj7Iiosxqtc8uiQXDHgZQqaq/sopeADADwL3B/89npIYnAIm2aaw+fN+8Xk5Rbs/CMD5yz1j3ddb7aL75Hs6PzBufn2fK6se6F0Hsy4CHG8wluuidbO/trgzj7+12J74cfKhrGO8+Yhpx1Udrne3O7DkojMf2GuaUuWvagTqBVC6AfgbAvwD4WEQ+DJ77ARJJ/rSIzASwCcC1GakhEaVFKmfj38an7qgOXZze6hBRpnBoUxpEl0a2L3OVd3fPUwwv7B/Gq3O2O2WoN5eoiqw74Eq6uXe9dc81+6s/7E4kudmK8/PMx9tgjWgDgItLx4bx443fdsp+uv9J87oiU4+BuX2c7V6bdG8Ylxa63Ql70spkk2G0xn4P+/JmtNtkl/HSXnL8zRB5gslO5Ak249Mg2oxvbDJN5vxc91f81WHTwviO9XPd97Ga5DmF5nt4/CF3tdch1u7WlBQ6ZQ0NZtsSMaPpCiI3o9jN3etOvcgpi/U6LYxHPHpJGA8eOsDZzm66f2pe+jY03aOTgCRt/vPsfpvwyE7kCSY7kSeY7ESeYJ89A3JzzaUy+/IRAHxz2BVh/OiGV5yyyqKqMD5idUyv2HjA2e6CI+acwENTy5wysfq9q4/UhvGGg+5lvqHd3NfZPqj6yLzfQTMZRuUnS53t9jSY0XW9C9y771qal97W0iW6dXXm97G8dkMYnxap+6G4GSk48ZQzk+7L94kveWQn8gSTncgTbMZngN1cjF5Osudhf3LCbKfswre/H8b71TSfHzrdbSJX9jIf26hC9/u6MC8/jPfkHwnjS968zdnum8OuDuMNG5c7ZXOWmEuC2sX8LAVwR/LlWKP8opNjpNpitpvuR5vcUX73VJqRfP+3640w7llf4Gxn/753feGZpO+fatfiZMUjO5EnmOxEnmCyE3mCffYMi15OaoibJZZHWxM7AsCi881dZJ9fcncYV6m7JtxWq398eo7b95y2zkxwWWpNZDGvr9unvmuD6ZcfPOwu+5xTbi5txfLNENk/XvaAs11xkZmYozHS386zhgk3NJr3j64r997uVWF8zcf3OWVXFY8P4/EFZghvY747fLj2qPmZJ7/5fafsxUk/DuNueUVhHD2X4kMfnkd2Ik8w2Yk8wWZ8lhXkFCQtG97DNJnfsCaGuOSt7znbVTUeDOMdR9wm7ZXLasO4cKiZy339ae5HXVhv5qJfIflO2b4SM0lFVRfT9J1T/bqz3e3dvhjGPQvdJarsu+AK8pL/zAO7m8k8+nU9xSn7aJ8ZNVd1eFcY16vbZbDvsNvd4I42jNyM5zUe2Yk8wWQn8gSb8RkWndThR6seC+OjkXnhHtlollPadrkZPXZF/0nOdn+pNTfQbDjinlV+6YtmiaaiAvNdHotMR92zqykrzHXLVh0w9dpeb1ahvWfrs249qv8Wxo+OvdUpG9d7RBjPX7sojF/b/K6zXf8upul+a9cLnLL795nfR1VTbRh3EffP9qHRZg69mwZPBTWPR3YiTzDZiTzBZCfyBPvsaRAdjWWLztf+cnVFGO+qd9bHRK88c6msrtHc9XbT4CnOdg9vXRDGy/Pc9++fby5zndndfLynr9rnbHf5e+ZS1iNXlztlxVZfv36HmRhidYO7r+XdtoTxxe/8wCl7aPS/hvGsZ02fen9RPVzWyLVGd2lqOcWcf5DhJj6vabiznd1Pb2ElLu+1emQXkS4iskREPhKRlSJyd/D8UBFZLCLrROQpEUl+MZWIOlwqzfh6AFNUdQyAsQCmichEAPcB+LWqDgewF8DMjNWSiNotlbXeFMCxOw3yg38KYAqAG4Ln5wK4C8BD6a9i52evnAq4N79c9c5dTlnPQtNUP2A11QFgnZr169/dY1ZZndY/5mx3Vr5ZPXUttjhlfazLZt27mMkmSvu488uvnWAueQ3s55YVWstQffkvZu6690rckXa3jjfvsavuoFN2w/u/COPcQlOP3Dr3UmS/7v3CuLx3f6esoJ8ZUdjTWjbr9qFXOdvFYS0TFWnHczkoI9X12XODFVx3AlgI4BMAtarhuMWtAAYmeTkRdQIpJbuqNqnqWADlAMYDGJnqDkRklohUiEhFTU1N6y8goow4rjaOqtYCeB3AeQCKRcKhTOUAqpK8Zo6qxlQ1Vlpa2twmRJQFrfbZRaQUwFFVrRWRIgBTkTg59zqAawDMAzADwPOZrOiJ6nNl5zqP5255NYxX17n9bRSbvvOyfevDOC/SH16x5s0wrh/oXjZblm++v/N3mQkn0bfI2W5Eb9P/HvXmLqds05jiMF45ydwBFy9wjw2f72n+fJZFDhsf1ZsnastM31urNjjbPXy5mbBi2ojJaDdea0sqlevsZQDmikguEi2Bp1V1voisAjBPRH4CYBmAhzNYTyJqp1TOxi8HMK6Z59cj0X8nohMAR9ClgbbQdLywz1nO41srfxfGk/uc7ZStOrg1jH9e+UQYN65e42x3uNFM0CBb3UtZOwcPCeP3rSWgD+9155nru9OMZBv5gTvH3baRZqnnquFmUoq+mw45251jXdqLbXHLlteZS4C/79c1jHfVu8s+jzzFzC0XnXu+KW4m5hDrcqZELq8lXdqZHPwtEXmCyU7kCTbj0yCnhVPA0RVHn/lHs+TT1QPdSSmerno7jK97644w1gb3hhl0M7chPDDhu05R/0FnmPdYauaxO1Tn3sRypJs5G7/iKnc8VFmxKRuz0VwJOPelbc52L88cFsaj9rk3sQytNd2G5UNNV+CD/sXOdq/UrgjjWb1PdcrsprszEo5n3NuER3YiTzDZiTzBZCfyBPvsaSCSvBPZs6Cb89jup8fj7mWza62yPxeb+OUSdwTdjyeayR3/7dybku67e44Zkfe15f/tlP0t11y++6TRrcfIbeZuvO3WBBjV17p96i555lixItbbKWuyfiVnW29fF3cvAf680kzA+aVB/+SU9cw3l+x8X245HXhkJ/IEk53IE2zGZ1h0VFjcao7m5CRvjv7h0p+G8e4L3JVJ+3azbk6JdAXs0XyXDZgQxkuLH3S2++1mMyf7w+sXOGXvHjYj6jYfNaPYNvfIdbabWG/Kyru6f0o9rUtvVy+vDeP603o621X3MCPvntn0V6fsa8MvC+NGay6/vBz+2bYFj+xEnmCyE3mCyU7kCXZ+Mix6WS43xctGuTmmf2z30QH3bjB7uyi7n1vW1X2P6d1Hh/G8rY84ZVsbd4fxJ73NpJIH1J0A44jVZ5+U5x43xlqTVg6sNpNonDbC7bOX55vfxwvb3nDK7D57Do9L7cbfIJEnmOxEnmAzvpNyLtlFWv4tNd2TLUVV1+COwrvoiRvCeH+OWyYHTfO/qNaUHRrgznf3bqkZHZi3113WqUeJGb23//rBYdwUGa1Xtsdcolt6YK1Ttu2QmRtvgNUNif6MHFGXGh7ZiTzBZCfyBJvxnVRLN9e0+DqrSWuPNOte4H7Uv5hiJsf491fvdsrq88xZ9nNGm5tTfnf+bGe712vNElV3rf2NUzbgoHmPyVare9gOd8mrtdbIu7wit4n/Qe06835WMz4eGZWY28bflW94ZCfyBJOdyBNMdiJPsM9+AmppIoea+towvuNjs0hPrPh0Z7vivmbZ5x4DhjllDYf3hPG7cTPJ5NdXuf3yNy80yzJX7neXsnrryCthPK7WTEZ59sp97nbnmr54YaH7s6w5sCmMr8DEMFZ1+/bgsswpSfm3FCzbvExE5gePh4rIYhFZJyJPiUhBa+9BRB3neL4SbwFQaT2+D8CvVXU4gL0AZqazYkSUXik140WkHMDlAO4BcKskrgtNAXBsGNZcAHcBeCgDdaSIlprx260m+NzNC8P4rzuXO9ttPLTDPLBGuwFATo/+Zl9dzPHgnZqPne02HTTvce8Y97t+9CtvhPGSXmaUXO3kfs52DfXmZ8lrcpvnGw5uR7OaHyRIrUj1yP4AgNsBHPs0TgFQq6rHxlVuBTCwmdcRUSfRarKLyBUAdqrq0rbsQERmiUiFiFTU1NS05S2IKA1SObJ/BsAXRGQjgHlINN8fBFAsIse6AeUAqpp7sarOUdWYqsZKS0vTUGUiaotU1mefDWA2AIjIRQC+p6o3isgzAK5B4gtgBoDnM1dNstkTOUQnnDy72FxG+6+RXw7j9/a7yz7bff1t9budsnpr0ovvll8ZxjMHXepsN6jIfHlHl02e3n9yGC85ZCa3zDvo1teeRb4JruojtWiORPZln8OIDqW1l3f2fWnn9vz0dyBxsm4dEn34h1vZnog60HENqlHVNwC8EcTrAYxPf5WIKBM4gu4EF23S2l7ctjiMNxx1T45+c7CZ3+3pLe587fayS78c/fWk79/YZJr7OepeAvzKkM+G8bPvzg/jvHx3u3yrO3Eozy3bD/cOuWMkcunNvkPwU3fAWQ/tLo+PTXr/fmIiTzHZiTzBZvwJL/nyUk+f9x9hvHSPO7/bP3/0szCe2nOMU7bqwOYwXnvAXFE9rVuZs12OdQNK9Cz4WSVDzeu6mptw1h5161HUYF63NzKCrmvczH8XR/Im+J6G/WH84f4NTtnAQrO67Bk9BsFnPLITeYLJTuQJJjuRJ9hnP8F9es5083hwN3OHWb64c81/dcDUMD6/5Eyn7MXqJWFcXGDmho/2le2Ra/aSVIA7yu/GwWZf31jj9tnzjpjXHc11+/1ab+56O9xo5qXff/SQs13stZvDuLrQnQM/77B5z9lDrwnju0bf5Gznw2W5k/OnIqJPYbITeUJUszcTQCwW04qKiqztz3fZbJq2tCRTbUNdGI9c4E5yUdNkLpvlxt0uSZM1J93Cc+8J4wm9z3C2W2rNL790z9+dstsrzQq1cWv03nvn/dLZbnzvkWY7uJcAT6QVZGOxGCoqKpqdSP/E+SmIqF2Y7ESeYLITeYKX3k5idj892qe2h7dGbxSzT+PkWH3vltafi14CbLLmdi8u6B7G3zn9ame7H256LIzz6pwiNFm3t92/9tkwfnHST5ztLiw9q9kYAO5d81QY7847EsZL97qXAJ0+e2TYbk7uyXFMPDl+CiJqFZOdyBNsxnsi2sxucZnjNKyAbDf/7UuAt4yY7mz3xKbXwnhlkTtnaYE1UG5B07IwvnHxvc52Vw48P4z/su1vTtnuxgNhrLlmFGHfwuKkdW9pQpAT2cn5UxHRpzDZiTzBZjxlhH3m3h6lWZTbxdnumfPNBBuXvDPbKdvW3az4mlNnbpiZt+MtZzvncTwyIrTQarrHzVWBKX3HOZvZdcxpqYtzAuORncgTTHYiTzDZiTzBPjtlnD2SL3pHmT0J5Gvnu5fUbv7YrAC+sOFDUxCZe955ywL3+FUmvcL4sXHfD+MSa1QfELlDUE7OY2Cq67NvBHAAieW4GlU1JiK9ATwFYAiAjQCuVdW9makmEbXX8XyFTVbVsaoaCx7fCWCRqo4AsCh4TESdVHua8dMBXBTEc5FYA+6OdtaHTnLRiSDs5vOInuVO2YLPmAkr3qr5OIz/asUAUHPUXKL7h15DnbKr+k0I49Kikmb3C5y8887ZUv0JFcArIrJURGYFz/VT1WMzAlYD6Nf8S4moM0j1yD5JVatEpC+AhSKy2i5UVRWJLreXEHw5zAKAU089tV2VJaK2S+nIrqpVwf87ATyHxFLNO0SkDACC/3cmee0cVY2paqy0tDQ9tSai49bqkV1EugHIUdUDQXwpgB8BeAHADAD3Bv8/n8mK0snJuSwX6UdLjrnEdoE1KcUFkQkqUtXSenE+SKUZ3w/Ac8FY5zwAT6jqAhF5H8DTIjITwCYA12aumkTUXq0mu6quBzCmmed3A7g4E5UiovTjCDrqNDLdtD6R5n/PBL9/eiKPMNmJPMFkJ/IEk53IE0x2Ik8w2Yk8wWQn8gSTncgTTHYiTzDZiTzBZCfyBJOdyBNMdiJPMNmJPMFkJ/IEk53IE0x2Ik8w2Yk8wWQn8gSTncgTTHYiTzDZiTzBZCfyBJOdyBMpJbuIFIvIsyKyWkQqReQ8EektIgtFZG3wf0nr70REHSXVI/uDABao6kgkloKqBHAngEWqOgLAouAxEXVSrSa7iPQCcCGAhwFAVRtUtRbAdABzg83mArgyM1UkonRI5cg+FEANgN+LyDIR+V2wdHM/Vd0ebFONxGqvRNRJpZLseQDOAfCQqo4DcBCRJruqKgBt7sUiMktEKkSkoqampr31JaI2SiXZtwLYqqqLg8fPIpH8O0SkDACC/3c292JVnaOqMVWNlZaWpqPORNQGrSa7qlYD2CIiZwRPXQxgFYAXAMwInpsB4PmM1JCI0iLV9dlvBvC4iBQAWA/gK0h8UTwtIjMBbAJwbWaqSETpkFKyq+qHAGLNFF2c1toQUcZwBB2RJ5jsRJ5gshN5gslO5AkmO5EnmOxEnmCyE3lCEsPas7QzkRokBuD0AbAraztuXmeoA8B6RLEeruOtx2BVbXZcelaTPdypSIWqNjdIx6s6sB6sRzbrwWY8kSeY7ESe6Khkn9NB+7V1hjoArEcU6+FKWz06pM9ORNnHZjyRJ7Ka7CIyTUTWiMg6EcnabLQi8oiI7BSRFdZzWZ8KW0QGicjrIrJKRFaKyC0dURcR6SIiS0Tko6AedwfPDxWRxcHn81Qwf0HGiUhuML/h/I6qh4hsFJGPReRDEakInuuIv5GMTduetWQXkVwA/wPgcwBGAbheREZlafePApgWea4jpsJuBHCbqo4CMBHAt4LfQbbrUg9giqqOATAWwDQRmQjgPgC/VtXhAPYCmJnhehxzCxLTkx/TUfWYrKpjrUtdHfE3krlp21U1K/8AnAfgZevxbACzs7j/IQBWWI/XACgL4jIAa7JVF6sOzwOY2pF1AdAVwAcAJiAxeCOvuc8rg/svD/6ApwCYD0A6qB4bAfSJPJfVzwVALwAbEJxLS3c9stmMHwhgi/V4a/BcR+nQqbBFZAiAcQAWd0Rdgqbzh0hMFLoQwCcAalW1MdgkW5/PAwBuBxAPHp/SQfVQAK+IyFIRmRU8l+3PJaPTtvMEHVqeCjsTRKQ7gD8C+I6q7u+Iuqhqk6qOReLIOh7AyEzvM0pErgCwU1WXZnvfzZikqucg0c38lohcaBdm6XNp17TtrclmslcBGGQ9Lg+e6ygpTYWdbiKSj0SiP66qf+rIugCAJlb3eR2J5nKxiByblzAbn89nAHxBRDYCmIdEU/7BDqgHVLUq+H8ngOeQ+ALM9ufSrmnbW5PNZH8fwIjgTGsBgOuQmI66o2R9KmwRESSW0apU1V91VF1EpFREioO4CInzBpVIJP012aqHqs5W1XJVHYLE38NrqnpjtushIt1EpMexGMClAFYgy5+LZnra9kyf+IicaLgMwN+R6B/+MIv7fRLAdgBHkfj2nIlE33ARgLUAXgXQOwv1mIREE2w5gA+Df5dluy4AzgawLKjHCgD/GTw/DMASAOsAPAOgMIuf0UUA5ndEPYL9fRT8W3nsb7OD/kbGAqgIPps/AyhJVz04go7IEzxBR+QJJjuRJ5jsRJ5gshN5gslO5AkmO5EnmOxEnmCyE3ni/wETE8t04/X9pgAAAABJRU5ErkJggg==",
59
+ "text/plain": [
60
+ "<Figure size 432x288 with 1 Axes>"
61
+ ]
62
+ },
63
+ "metadata": {
64
+ "needs_background": "light"
65
+ }
66
+ }
67
+ ],
68
+ "metadata": {}
69
+ },
70
+ {
71
+ "cell_type": "code",
72
+ "execution_count": 5,
73
+ "source": [
74
+ "from jax_nca.nca import NCA"
75
+ ],
76
+ "outputs": [],
77
+ "metadata": {}
78
+ },
79
+ {
80
+ "cell_type": "markdown",
81
+ "source": [
82
+ "### NCA\n",
83
+ "- num_hidden_channels = 16\n",
84
+ "- num_target_channels = 3\n",
85
+ "- cell_fire_rate = 1.0 (100% chance for cells to be updated)\n",
86
+ "- alpha_living_threshold = 0.1 (threshold for cells to be alive)"
87
+ ],
88
+ "metadata": {}
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": 6,
93
+ "source": [
94
+ "nca = NCA(16, 3, trainable_perception=False, cell_fire_rate=1.0, alpha_living_threshold=0.1)"
95
+ ],
96
+ "outputs": [],
97
+ "metadata": {}
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": 7,
102
+ "source": [
103
+ "from jax_nca.trainer import EmojiTrainer"
104
+ ],
105
+ "outputs": [],
106
+ "metadata": {}
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": 8,
111
+ "source": [
112
+ "trainer = EmojiTrainer(dataset, nca, n_damage=0)"
113
+ ],
114
+ "outputs": [],
115
+ "metadata": {}
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": 9,
120
+ "source": [
121
+ "# trainer.train(100000, batch_size=8, seed=10, lr=2e-4, min_steps=64, max_steps=96)"
122
+ ],
123
+ "outputs": [],
124
+ "metadata": {}
125
+ },
126
+ {
127
+ "cell_type": "markdown",
128
+ "source": [
129
+ "#### Get current state from trainer"
130
+ ],
131
+ "metadata": {}
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": 10,
136
+ "source": [
137
+ "state = trainer.state"
138
+ ],
139
+ "outputs": [],
140
+ "metadata": {}
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": 11,
145
+ "source": [
146
+ "# save\n",
147
+ "# nca.save(state.params, \"saved_params\")"
148
+ ],
149
+ "outputs": [],
150
+ "metadata": {}
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": 12,
155
+ "source": [
156
+ "params = nca.load(\"gecko_100_cell_fire_rate\")"
157
+ ],
158
+ "outputs": [],
159
+ "metadata": {}
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": 13,
164
+ "source": [
165
+ "import numpy as np\n",
166
+ "import matplotlib\n",
167
+ "import matplotlib.pyplot as plt\n",
168
+ "%matplotlib ipympl\n",
169
+ "\n",
170
+ "plt.style.use('ggplot')\n",
171
+ "# Imports specifically so we can render outputs in Jupyter.\n",
172
+ "from JSAnimation.IPython_display import display_animation\n",
173
+ "from matplotlib import animation\n",
174
+ "from IPython.display import display\n",
175
+ "from celluloid import Camera\n",
176
+ "from IPython.display import HTML\n",
177
+ "import jax\n",
178
+ "import jax.numpy as jnp\n",
179
+ "\n",
180
+ "def render_nca_steps(nca, params, shape = (64, 64), num_steps = 2):\n",
181
+ " nca_seed = nca.create_seed(nca.num_hidden_channels, nca.num_target_channels, shape=shape, batch_size=1)\n",
182
+ " rng = jax.random.PRNGKey(0)\n",
183
+ " _, outputs = nca.multi_step(params, nca_seed, rng, num_steps=num_steps)\n",
184
+ " stacked = jnp.squeeze(jnp.stack(outputs))\n",
185
+ " rgbs = np.array(nca.to_rgb(stacked))\n",
186
+ "\n",
187
+ " fig = plt.figure(\"Animation\",figsize=(7,5))\n",
188
+ " camera = Camera(fig)\n",
189
+ " ax = fig.add_subplot(111)\n",
190
+ " frames = []\n",
191
+ " for r in rgbs:\n",
192
+ " frame = ax.imshow(r)\n",
193
+ " ax.axis('off')\n",
194
+ " camera.snap()\n",
195
+ " frames.append([frame])\n",
196
+ " animation = camera.animate(blit=False, interval=50)\n",
197
+ " animation.save('gecko.mp4')\n",
198
+ " return animation, outputs, rgbs"
199
+ ],
200
+ "outputs": [],
201
+ "metadata": {}
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": 14,
206
+ "source": [
207
+ "animation, outputs, rgbs = render_nca_steps(nca, params, num_steps=256)"
208
+ ],
209
+ "outputs": [
210
+ {
211
+ "output_type": "display_data",
212
+ "data": {
213
+ "application/vnd.jupyter.widget-view+json": {
214
+ "version_major": 2,
215
+ "version_minor": 0,
216
+ "model_id": "c0e2eeba79a046e1a9e1b56275e1c911"
217
+ },
218
+ "text/html": [
219
+ "\n",
220
+ " <div style=\"display: inline-block;\">\n",
221
+ " <div class=\"jupyter-widgets widget-label\" style=\"text-align: center;\">\n",
222
+ " Animation\n",
223
+ " </div>\n",
224
+ " <img src='' width=700.0/>\n",
225
+ " </div>\n",
226
+ " "
227
+ ],
228
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAArwAAAH0CAYAAADfWf7fAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZiElEQVR4nO3da5CdB33f8efs2dtZ7a7usmTJsnzD2JDY1Da1zd1JS4ZJgdIEDDEeksk4aWBat9M2tA2UAiE0wJRpMi6lbQopbV2mIVxKU8CkwfgSjI2NjS2IL9gW1sWyZGm1q72ePX3RN5nJ78lEjF3Zf30+L7+Wzh6tNKufnpn9uzMYDAYNAAAUNXSy3wAAADybDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKG34ZL8BgD+vvzqIfXGwFPtQy7/bx4ZGYu90frz3BcDzlye8AACUZvACAFCawQsAQGkGLwAApRm8AACU5koDcFKstPSlzmrs777v9/JPGOQf/8EX/1LsQ5387/zJ7ljLOwLg+c4TXgAASjN4AQAozeAFAKA0gxcAgNIMXgAASnOlAXhWtV1jaPvX9q/c+dHYP7P/T2IfHR6P/dbjfxb71y77QOyrLVcajvTnYt/QXRM7AM89nvACAFCawQsAQGkGLwAApRm8AACUZvACAFBaZzAYDE72mwDqWm3pn997a+xv/85HYh8aH4199vhs/gDD+d/zb97+stjff94vxH5+b2d+/U7OADz3eMILAEBpBi8AAKUZvAAAlGbwAgBQmsELAEBprjQAz4iFJn8pWe0vx77tK2/LrzPSctdhqR9zfyT/u31qJff5lrsRn7zonbG/ZfsrY2+7PtFr8jUJAE4eT3gBACjN4AUAoDSDFwCA0gxeAABKM3gBACht+GS/AaCG8aYT+w17bop9rr8Y++hy/nf4Qn+p5ePmqwgLs/OxD031Yn/Xdz8R+1WnXRz7juFNsZ+o1ZbrFu3y5/lEf/SJvQrA85snvAAAlGbwAgBQmsELAEBpBi8AAKUZvAAAlOZKA3BC2m4KtPXfefCLsXefXol9ea7lesO68dhXWq49NE/MxDz2hfvy6/zk5tg/tutzsf/mhdfGPtHk97na9GM/unI89pnF3A8uHYt94/ja2Hf18q8L4FTiCS8AAKUZvAAAlGbwAgBQmsELAEBpBi8AAKW50gCckLZrDI/O7Yv90OLh2Dddf0fss6PLsS/9k4vz+zlzIvbukXzlYHkmXzno7l6N/SsP3xb7Ry68Lvamn19nKB9paC7941/LLzPWjX1hkK9bXLv11bF/6MJfzu8nv81myGMQoCBf2gAAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM2VBuCEdFquBBzr5H5w4enYh8fzj1+/bWPsM+vyl6uhybHcLz0t9qWWqw7NYj6jcLRZyD++xWx3KfbvHn4w9jPGN8T+xEr+vI22fNn+T4/flF9n7qnY/8Ml/yj2nr8WgII84QUAoDSDFwCA0gxeAABKM3gBACjN4AUAoDTfjguckEHLl401w73Yx3vTsS9/+lWxH39qkD/wwvGYp1YWYx9bl68xHO7kl+8fz9cV5non9lzgwaUDsf/yff869r1PPxn7Uid/Hpa7+eNuGpmK/Q2bXxr7kMcdwCnElzwAAEozeAEAKM3gBQCgNIMXAIDSDF4AAEpzpQE4IYOW6wGnjeRrDOeNboz94dHZ2Ec3zsc+cij/+/y01Zibbj//h3Xr8/WGx1bz+Ya18/njzi7PxX7R6I7Yb7z412P/4Lf/fexfPHxn7KdNbYn9ziv/TezrpzfE3l1oOVcxnjPA85knvAAAlGbwAgBQmsELAEBpBi8AAKUZvAAAlOZKA3BCuk3+7v5eyzf9v3HbFbF/YP/n8+sP9WJftzlfXXjJzELsE938hh5r+rF3pvN5gtkjy7FPrXTz64zkKxYXbDg79s+88n2x7/jPr4t9op/fz+a1m2If7rQ813iWrzHMDfL7HGp5P6NN/ny2PZVp+ePWDFbzn5POkOc7cCrzFQAAgNIMXgAASjN4AQAozeAFAKA0gxcAgNJcaQCeEcP9/O/nd537+tj/YO/NsT+wvC/2pZbvsr+kpV9xcCV/3DPyNYD13dzvH12K/YHlJ2K/sHdW7J0mXy24Zc/dsc89fTT2wWr+9R6YOxT79snNsT/bhlrOKDw8uyf2R+f3xz45nK92LK7k35dXrntx7L2W6x/AqcETXgAASjN4AQAozeAFAKA0gxcAgNIMXgAASusMBoP8P34HeAYMmvwl5pGj+crBxTf/auz9yfz6Fyytxn71lqnYp2dmYt/fcu3h7mY89kdX18R+/Xlvjf3+Pbtj/8yt/yP2g8fy1YXpiXWxP3T912LfMDId+4ma7y/GPtTNx34WmnxF4bo7Pxb7l578Vv7Ag/z7u66ZiP3el98Q+8bpTbF3Oi3nJIBSPOEFAKA0gxcAgNIMXgAASjN4AQAozeAFAKA0VxqAZ9VKS58Z5O/iPzC7N/Yrbn5n7CNNvh5wxbpe7H9nNL+fC/fNx37PmZtj/29z+Vf2wIG52OdH8rWBxaePxX7OsXyW4n++Pl8hOGPrzthXhvMVgvEmX1doM+jnvyq+fuCO2P/2t9+X+6aXxb5n9XDsc4v592W25QrEttV8neMLr/pg7JNjLec/gFI84QUAoDSDFwCA0gxeAABKM3gBACjN4AUAoLQT+zZdgBZt517ybYKmmejk6wFnTm2P/ZbXfCL219x0Xez7+vmKwsGRNbGPHM/XAF53x2OxP3rJ6bEvr81XAh6Zn419Zku+JvHIjunYP71wc+zXPP3q2LdtPjN/3JYrB5Or+a+Ffjf/Du+Y3BL7lt6G2O/q58/n4YV8rWIwyH9OxkbzuY29zUzsQ0Oe78CpzFcAAABKM3gBACjN4AUAoDSDFwCA0gxeAABK6wwGg7Zvrgb4K+u39Pff/6nYx5v83ffvv/+/xP7km78c+wd2/37sHz/4udivXTMW+yXd/CtYu2Uy9qPH5mKfnR2J/fa55djvWcx9ppuvEBwfyZ+3yXycoPnKlf8q9rN7W2O/+cDdsf/xnjtiP218few71uTX//juP4j9e8P7Y1/Kn4bmN85+S+z/4Iw3xD7VWxv7cMufQ6AWT3gBACjN4AUAoDSDFwCA0gxeAABKM3gBACgt/0/TAVqs9PNVgZVOvnLw2Yf/d+zHVuZjP723Lvbh1YXYrz7rqtg/vTd/3HuO5dfprc3frX/hYv51vfR7h2M///F8LmHyyrNy7+QvwwtPHo/94al8ZeLekaXY3/Sdfxb7h8/6u7H//T/8h7EfGMrvZ2J4OvZObyr25S357MLSaP78XzZ9QezX78jXGCbG81UN1xjg1OYJLwAApRm8AACUZvACAFCawQsAQGkGLwAApbnSAJyQoU439jfd8huxbxxbG/uh5dnY9/WPxP6NfffG/tPbr4x9Wz9fD3h0ZSX2DcfzNYbtE4PYn1o7EfvwRfnjnnVuy49fGYn9Z277Qey7N+QrBO95Yb5+sKflGsZbH/hw7J1mMfZmIV/nmBjkz9v568+I/chwfs7S6+Vf1zVbXpE/7lgv9rGWqxfAqc0TXgAASjN4AQAozeAFAKA0gxcAgNIMXgAASvPtrMCJyUcLmtfvfFXsv/vg52Kfy0cFmsHxfCXgS0/cHvtQpxP7Az/6fuy9Tevyj5/I3/U/fSRfdeidPhX77Pqx2F9w31OxL1+6I/a9l26Mfe14/g147XS+cnDXbL7S8O2hfB1i+bStsc88cSj2G3/+htiv2PSi2Jux/HG7w7kPNfn3d6jreQ3wV+crBgAApRm8AACUZvACAFCawQsAQGkGLwAApXUGg0HL91wD/EWry/3YH+/n7+L/iZt+JfaXTp8T+91P7Y796bn8+pOzq7Ef7+brCpu27oz9qen869qwmvtPTOQzE780nI/fXP3FP4v9a792fuzzs/k6wZbF/JziaMvji6Xd+fP2neX8pf9TW9fEfmjfsdi/9ap8peGcjWfEPjGaXx/g2eQJLwAApRm8AACUZvACAFCawQsAQGkGLwAApRm8AACUlu/nALQZ6ca8uTMV+yde8vdi/7nTXhb7v33oD2N/93f/XX4/q7M5H5+L/UOv+YXYN27K58re/KcfjP3Ohfxxhyd7se/+2XyGbXIlfxm+qJ9f/6Iv7In9lp//ydhf0clnwDa2nJf707F8bu0HuzbHftfCQ7G/oHtm7AAngye8AACUZvACAFCawQsAQGkGLwAApRm8AACU1hkMBoOT/SaA57/FwUrsQ51O7EuDfCVgtJ//HX7NHflawm0H7439F7f9zdjfffE1+f108/WJe44+GPvPfOO9sfd25WsVEzMHY3/h6HTs5w4txn7l6ELsm8dOj70/nT/Pex85Gvuefv79ums0X2842l8X+22vuCH21ZY/D0NN/vwDPBM84QUAoDSDFwCA0gxeAABKM3gBACjN4AUAoDRXGoDnlJUmf0labekr/eXY265DjA3lawP9Qb6KsLqanwscHByL/cYHvhj7R777+7Ef3ZLfz+n56EVz4fa1sb+ytxT7mWO92LvH8uftnN3zsX9y03Ds3xzP7/+f77wu9rfs+KnYO03+/QJ4JnjCCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaa40APw4VvKVg73H98f+6hvfEfuDg4OxD593Ruyblruxv2IqP7/46XX5SsMZx/NVihc90Y/966P54/7XdfmvkInBubF//q//ZuydoZHYm5ZrGwAnwhNeAABKM3gBACjN4AUAoDSDFwCA0gxeAABKy/9zdAD+Usv9pdjf/NVfj/1HK8di787n1+kdOB771m35+sE3534U+0LLkYM3rs/POxbOH419eF++6rB2Il9XuPfwI7E/uvhk7LsmdsTuRgPwTPCEFwCA0gxeAABKM3gBACjN4AUAoDSDFwCA0lxpAGiaZqFZjX11OV9LmBlaiP2t57429rtn8xWF/syh2C/YcFbsn738fbHfPv947O958OOxr5+dj/1vjfZjn1g7EfvUsfz5aXr5ecp9R34Y+66J7fl1BjnPr+arEb3ueP4JwCnNE14AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEpzpQGgaZql/lLsb7/rw7H/3PaXxb5pakvs3bF8bmD9ztNjv395f+xvuPVfxn7bT90Q+94zron9U9//ROyXr52O/W/88Ejs/2fjaOzjw53Y7zv2aOyvb14e+5eevDP2l/byFYvetCsNwF/kCS8AAKUZvAAAlGbwAgBQmsELAEBpBi8AAKW50gDQNM3KoB/7zfvuif3x2Sdjf+ypH8be6Y3EvjzI1xsW8w9vHmq53jDf5CsT1257bewfvf8/xv7thYX8gU/Lf10MtTw3mVjIn88fze7Nr9/isqlzY984vOaEXgc4tXnCCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaa40ADRNMzo8Fvs/ffHbY/9fB+6Ifc3a9bE/sXwo9omJjbH/9tlvi/3qs14d+1jLl/ORfu5XT14e+5cHt+XXGYzGvnRsLvbVtfmKwtzqfOz9wWrsWyfy5xPgRHjCCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaa40ADRNM9J0Y//ko38U+97FfHXh+nPeEPt/f+KbsZ89vTP2q6ZfHHt3sRP75Hi+orDazdcPrr3o6thvvOkbsd+/Pj8f2bBxbewjy8uxz3fy6wwt9WNfHs2/L0urS7GPdvOPb/v9PbaUr0xMjeYrE8Dzkye8AACUZvACAFCawQsAQGkGLwAApRm8AACU5koDQNM0nWYQ+5+8/GOx33b0/tjfcedHY79o/dmx75s7EPvOqa2x98bGY2+z1PLruqC3K/Zdo9tif7h/JPa9M0/HvjIxld/Pkf2xz4/F3Ay3vP8DC/lKxg8W98X+wvHTY98+sSn2hdV8ZaJpmmZ8aKT1vwHPTZ7wAgBQmsELAEBpBi8AAKUZvAAAlGbwAgBQmisNAE3TDOdjAM3IWP4yeV4vf9f/G7ddGfvOsc2xH1o+lj9utxv7eHNiFwLGm/w6yy3XD9552bWx/+Jtvx37+pbHJjNzx2MfDD8V+/Hl+diP9g/Hfu6t1+UPPNLy19pS/vVev/11sX/oRe/Ir9M0TX+QX6vb6bT+HODk8oQXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSOoNBy7ebAtAsDJZj77R8R/5Q09bz84W2L8BtX5lHnqFLACstH3m56cd+4dffHvvMfL4ysbK0FPvi+Hjsn7/8vbFfPnl+7N9f3Bf7LYd2x/6PH/y92JuV/Ov96sX/Iv/4pmmu2npp7K40wHOXJ7wAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJTW8j8dB6Bpmma8M3JyPvCz/A3/w63XJPJfCx99ybtif9utvxX7+jVTsc8cz1cdbtjzR7Hv2Lkh9sunzo790u3nxP5bD3829sNDM7Hf8/TDsTdN01y19ZKW/+JKAzxXecILAEBpBi8AAKUZvAAAlGbwAgBQmsELAEBprjQA8Of0Y/3ZDZfF/qYzr4j9xv3fir3bz5cMvvTYrbF38ttpdnXWxn5g7kjswy0HFDr9Qexbevk6RNM0TWeQf44jDfDc5QkvAAClGbwAAJRm8AIAUJrBCwBAaQYvAACldQaDtm83BYD/59hgPvZOJ58muOKrvxr7QyuHY184nl9/bHIy9olON/a5ZiX2TsvfdC/obYv9yy95b/4JTdNsndwS+8hQfk/AyecJLwAApRm8AACUZvACAFCawQsAQGkGLwAApbnSAMCPbbHlKsKhxUOxv+n298f+vYX9sc/NHYl9TS9fb+jmoxHN5qG1sX/hkvfEfu76XfmFmqYZ7o7kj936M4CTzRNeAABKM3gBACjN4AUAoDSDFwCA0gxeAABKc6UBgB/bwmAh9k4n3ywYabllcPfRR2Lff2hf7Lcv5B9/3uTpsb9uw1+LfcOaDbH/ZX8zDg+1nIIAnrM84QUAoDSDFwCA0gxeAABKM3gBACjN4AUAoDRXGgD4/6bf9GNfbfmbaGZ1PvZ1q2P59bv5gsJwy3WIIRcX4JTgCS8AAKUZvAAAlGbwAgBQmsELAEBpBi8AAKW50gAAQGme8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaf8XdahW29FwEN8AAAAASUVORK5CYII=",
229
+ "text/plain": [
230
+ "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
231
+ ]
232
+ },
233
+ "metadata": {}
234
+ }
235
+ ],
236
+ "metadata": {}
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "execution_count": 15,
241
+ "source": [
242
+ "from IPython.display import Video\n",
243
+ "\n",
244
+ "Video(\"gecko.mp4\")"
245
+ ],
246
+ "outputs": [
247
+ {
248
+ "output_type": "execute_result",
249
+ "data": {
250
+ "text/html": [
251
+ "<video src=\"gecko.mp4\" controls >\n",
252
+ " Your browser does not support the <code>video</code> element.\n",
253
+ " </video>"
254
+ ],
255
+ "text/plain": [
256
+ "<IPython.core.display.Video object>"
257
+ ]
258
+ },
259
+ "metadata": {},
260
+ "execution_count": 15
261
+ }
262
+ ],
263
+ "metadata": {}
264
+ }
265
+ ],
266
+ "metadata": {
267
+ "orig_nbformat": 4,
268
+ "language_info": {
269
+ "name": "python",
270
+ "version": "3.9.0",
271
+ "mimetype": "text/x-python",
272
+ "codemirror_mode": {
273
+ "name": "ipython",
274
+ "version": 3
275
+ },
276
+ "pygments_lexer": "ipython3",
277
+ "nbconvert_exporter": "python",
278
+ "file_extension": ".py"
279
+ },
280
+ "kernelspec": {
281
+ "name": "python3",
282
+ "display_name": "Python 3.9.0 64-bit ('jax_gpu': conda)"
283
+ },
284
+ "interpreter": {
285
+ "hash": "a7271dcc4a91420ffb9cc5ce7ff5a5d83d948f729c0ba20dec48f9a748a86390"
286
+ }
287
+ },
288
+ "nbformat": 4,
289
+ "nbformat_minor": 2
290
+ }
notebooks/gecko.mp4 ADDED
Binary file (66.5 kB). View file
notebooks/gecko_100_cell_fire_rate ADDED
Binary file (37.5 kB). View file
setup.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import re
4
+ from os import path
5
+
6
+ from setuptools import find_packages
7
+ from setuptools import setup
8
+
9
+
10
+ this_directory = path.abspath(path.dirname(__file__))
11
+ with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
12
+ long_description = f.read()
13
+
14
+
15
+ setup(
16
+ name="jax_nca",
17
+ version="0.1.2",
18
+ url="https://github.com/shyamsn97/jax-nca",
19
+ license='MIT',
20
+
21
+ author="Shyam Sudhakaran",
22
+ author_email="shyamsnair@protonmail.com",
23
+
24
+ description="Neural Cellular Automata (https://distill.pub/2020/growing-ca/ -- Mordvintsev, et al., \"Growing Neural Cellular Automata\", Distill, 2020) implemented in JAX",
25
+
26
+ long_description=long_description,
27
+ long_description_content_type="text/markdown",
28
+
29
+ packages=find_packages(exclude=('tests',)),
30
+
31
+ install_requires=[
32
+ 'numpy',
33
+ 'jax',
34
+ 'flax',
35
+ 'matplotlib',
36
+ 'einops',
37
+ 'tensorflow',
38
+ 'tensorboardX',
39
+ 'optax',
40
+ 'tqdm',
41
+ 'pandas',
42
+ 'pillow',
43
+ 'ipycanvas',
44
+ 'orjson',
45
+ 'opencv-python'
46
+ ],
47
+
48
+ classifiers=[
49
+ 'Development Status :: 2 - Pre-Alpha',
50
+ 'License :: OSI Approved :: MIT License',
51
+ 'Programming Language :: Python :: 3.9',
52
+ ],
53
+ )