shyamsn97
commited on
Commit
•
434b57f
1
Parent(s):
4b71f88
first commit
Browse files- .flake8 +11 -0
- .gitignore +53 -0
- LICENSE +21 -0
- Makefile +28 -0
- README.md +151 -0
- images/gecko.gif +0 -0
- images/model.svg +0 -0
- jax_nca/__init__.py +0 -0
- jax_nca/dataset.py +36 -0
- jax_nca/nca.py +203 -0
- jax_nca/trainer.py +214 -0
- jax_nca/utils.py +39 -0
- notebooks/Gecko.ipynb +290 -0
- notebooks/gecko.mp4 +0 -0
- notebooks/gecko_100_cell_fire_rate +0 -0
- setup.py +53 -0
.flake8
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[flake8]
|
2 |
+
ignore = E203, E266, E501, W503, C901 E731
|
3 |
+
# line length is intentionally set to 80 here because black uses Bugbear
|
4 |
+
# See https://github.com/psf/black/blob/master/README.md#line-length for more details
|
5 |
+
max-line-length = 1000
|
6 |
+
max-complexity = 18
|
7 |
+
select = B,C,E,F,W,T4,B9
|
8 |
+
# We need to configure the mypy.ini because the flake8-mypy's default
|
9 |
+
# options don't properly override it, so if we don't specify it we get
|
10 |
+
# half of the config from mypy.ini and half from flake8-mypy.
|
11 |
+
mypy_config = mypy.ini
|
.gitignore
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.py[cod]
|
2 |
+
|
3 |
+
# C extensions
|
4 |
+
*.so
|
5 |
+
|
6 |
+
# Packages
|
7 |
+
*.egg
|
8 |
+
*.egg-info
|
9 |
+
dist
|
10 |
+
build
|
11 |
+
eggs
|
12 |
+
parts
|
13 |
+
bin
|
14 |
+
var
|
15 |
+
sdist
|
16 |
+
develop-eggs
|
17 |
+
.installed.cfg
|
18 |
+
lib
|
19 |
+
lib64
|
20 |
+
__pycache__
|
21 |
+
|
22 |
+
# Installer logs
|
23 |
+
pip-log.txt
|
24 |
+
|
25 |
+
# Unit test / coverage reports
|
26 |
+
.coverage
|
27 |
+
.tox
|
28 |
+
nosetests.xml
|
29 |
+
|
30 |
+
# Translations
|
31 |
+
*.mo
|
32 |
+
|
33 |
+
# Mr Developer
|
34 |
+
.mr.developer.cfg
|
35 |
+
.project
|
36 |
+
.pydevproject
|
37 |
+
test.json
|
38 |
+
*.pickle
|
39 |
+
venv
|
40 |
+
.idea
|
41 |
+
*.vscode/
|
42 |
+
|
43 |
+
#notebooks
|
44 |
+
*/**/.ipynb_checkpoints/*
|
45 |
+
|
46 |
+
# logs
|
47 |
+
*/**/checkpoints/*
|
48 |
+
*/**/mlruns/*
|
49 |
+
*/**/tensorboard_logs/*
|
50 |
+
*/**/wandb/*
|
51 |
+
checkpoints/*
|
52 |
+
wandb/*
|
53 |
+
mlruns/*
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The MIT License (MIT)
|
2 |
+
|
3 |
+
Copyright (c) 2021 Shyam Sudhakaran
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
Makefile
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts
|
2 |
+
|
3 |
+
clean-build: ## remove build artifacts
|
4 |
+
rm -fr build/
|
5 |
+
rm -fr dist/
|
6 |
+
rm -fr .eggs/
|
7 |
+
find . -name '*.egg-info' -exec rm -fr {} +
|
8 |
+
find . -name '*.egg' -exec rm -f {} +
|
9 |
+
|
10 |
+
clean-pyc: ## remove Python file artifacts
|
11 |
+
find . -name '*.pyc' -exec rm -f {} +
|
12 |
+
find . -name '*.pyo' -exec rm -f {} +
|
13 |
+
find . -name '*~' -exec rm -f {} +
|
14 |
+
find . -name '__pycache__' -exec rm -fr {} +
|
15 |
+
|
16 |
+
clean-test: ## remove test and coverage artifacts
|
17 |
+
rm -fr .tox/
|
18 |
+
rm -f .coverage
|
19 |
+
rm -fr coverage/
|
20 |
+
rm -fr .pytest_cache
|
21 |
+
|
22 |
+
lint: ## check style with flake8
|
23 |
+
isort --profile black jax_nca
|
24 |
+
black jax_nca
|
25 |
+
flake8 jax_nca
|
26 |
+
|
27 |
+
install: clean lint
|
28 |
+
python setup.py install
|
README.md
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Neural Cellular Automata (Based on https://distill.pub/2020/growing-ca/) implemented in Jax (Flax)
|
2 |
+
|
3 |
+
![Gecko gif](https://raw.githubusercontent.com/shyamsn97/jax-nca/main/images/gecko.gif?token=GHSAT0AAAAAABTB4G7FLAJSLDHSIOQONS3IYTB5ZEA)
|
4 |
+
|
5 |
+
---
|
6 |
+
|
7 |
+
|
8 |
+
## Installation
|
9 |
+
from source:
|
10 |
+
```
|
11 |
+
git clone git@github.com:shyamsn97/jax-nca.git
|
12 |
+
cd jax-nca
|
13 |
+
python setup.py install
|
14 |
+
```
|
15 |
+
|
16 |
+
from PYPI
|
17 |
+
```
|
18 |
+
pip install jax-nca
|
19 |
+
```
|
20 |
+
---
|
21 |
+
|
22 |
+
## How do NCAs work?
|
23 |
+
For more information, view the awesome article https://distill.pub/2020/growing-ca/ -- Mordvintsev, et al., "Growing Neural Cellular Automata", Distill, 2020
|
24 |
+
|
25 |
+
Image below describes a single update step: https://github.com/distillpub/post--growing-ca/blob/master/public/figures/model.svg
|
26 |
+
|
27 |
+
![NCA update](https://raw.githubusercontent.com/shyamsn97/jax-nca/main/images/model.svg?token=GHSAT0AAAAAABTB4G7FOWOPXEUYVLBGRNSWYTB5YUA)
|
28 |
+
|
29 |
+
---
|
30 |
+
|
31 |
+
## Why Jax?
|
32 |
+
|
33 |
+
<b> Note: This project served as a nice introduction to jax, so its performance can probably be improved </b>
|
34 |
+
|
35 |
+
NCAs are autoregressive models like RNNs, where new states are calculated from previous ones. With jax, we can make these operations a lot more performant with `jax.lax.scan` and `jax.jit` (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html)
|
36 |
+
|
37 |
+
Instead of writing the nca growth process as:
|
38 |
+
|
39 |
+
```python
|
40 |
+
def multi_step(params, nca, current_state, num_steps):
|
41 |
+
# params: parameters for NCA
|
42 |
+
# nca: Flax Module describing NCA
|
43 |
+
# current_state: Current NCA state
|
44 |
+
# num_steps: number of steps to run
|
45 |
+
|
46 |
+
for i in range(num_steps):
|
47 |
+
current_state = nca.apply(params, current_state)
|
48 |
+
return current_state
|
49 |
+
```
|
50 |
+
|
51 |
+
We can write this with `jax.lax.scan`
|
52 |
+
|
53 |
+
```python
|
54 |
+
def multi_step(params, nca, current_state, num_steps):
|
55 |
+
# params: parameters for NCA
|
56 |
+
# nca: Flax Module describing NCA
|
57 |
+
# current_state: Current NCA state
|
58 |
+
# num_steps: number of steps to run
|
59 |
+
|
60 |
+
def forward(carry, inp):
|
61 |
+
carry = nca.apply({"params": params}, carry)
|
62 |
+
return carry, carry
|
63 |
+
|
64 |
+
final_state, nca_states = jax.lax.scan(forward, current_state, None, length=num_steps)
|
65 |
+
return final_state
|
66 |
+
```
|
67 |
+
The actual multi_step implementation can be found here: https://github.com/shyamsn97/jax-nca/blob/main/jax_nca/nca.py#L103
|
68 |
+
|
69 |
+
---
|
70 |
+
|
71 |
+
## Usage
|
72 |
+
See [notebooks/Gecko.ipynb](notebooks/Gecko.ipynb) for a full example
|
73 |
+
|
74 |
+
<b> Currently there's a bug with the stochastic update, so only `cell_fire_rate = 1.0` works at the moment </b>
|
75 |
+
|
76 |
+
Creating and using NCA:
|
77 |
+
|
78 |
+
```python
|
79 |
+
class NCA(nn.Module):
|
80 |
+
num_hidden_channels: int
|
81 |
+
num_target_channels: int = 3
|
82 |
+
alpha_living_threshold: float = 0.1
|
83 |
+
cell_fire_rate: float = 1.0
|
84 |
+
trainable_perception: bool = False
|
85 |
+
alpha: float = 1.0
|
86 |
+
|
87 |
+
"""
|
88 |
+
num_hidden_channels: Number of hidden channels for each cell to use
|
89 |
+
num_target_channels: Number of target channels to be used
|
90 |
+
alpha_living_threshold: threshold to determine whether a cell lives or dies
|
91 |
+
cell_fire_rate: probability that a cell receives an update per step
|
92 |
+
trainable_perception: if true, instead of using sobel filters use a trainable conv net
|
93 |
+
alpha: scalar value to be multiplied to updates
|
94 |
+
"""
|
95 |
+
...
|
96 |
+
|
97 |
+
from jax_nca.nca import NCA
|
98 |
+
|
99 |
+
# usage
|
100 |
+
nca = NCA(
|
101 |
+
num_hidden_channels = 16,
|
102 |
+
num_target_channels = 3,
|
103 |
+
trainable_perception = False,
|
104 |
+
cell_fire_rate = 1.0,
|
105 |
+
alpha_living_threshold = 0.1
|
106 |
+
)
|
107 |
+
|
108 |
+
nca_seed = nca.create_seed(
|
109 |
+
nca.num_hidden_channels, nca.num_target_channels, shape=(64,64), batch_size=1
|
110 |
+
)
|
111 |
+
rng = jax.random.PRNGKey(0)
|
112 |
+
params = = nca.init(rng, nca_seed, rng)["params"]
|
113 |
+
update = nca.apply({"params":params}, nca_seed, jax.random.PRNGKey(10))
|
114 |
+
|
115 |
+
# multi step
|
116 |
+
|
117 |
+
final_state, nca_states = nca.multi_step(poarams, nca_seed, jax.random.PRNGKey(10), num_steps=32)
|
118 |
+
```
|
119 |
+
|
120 |
+
To train the NCA:
|
121 |
+
```python
|
122 |
+
from jax_nca.dataset import ImageDataset
|
123 |
+
from jax_nca.trainer import EmojiTrainer
|
124 |
+
|
125 |
+
|
126 |
+
dataset = ImageDataset(emoji='🦎', img_size=64)
|
127 |
+
|
128 |
+
|
129 |
+
nca = NCA(
|
130 |
+
num_hidden_channels = 16,
|
131 |
+
num_target_channels = 3,
|
132 |
+
trainable_perception = False,
|
133 |
+
cell_fire_rate = 1.0,
|
134 |
+
alpha_living_threshold = 0.1
|
135 |
+
)
|
136 |
+
|
137 |
+
trainer = EmojiTrainer(dataset, nca, n_damage=0)
|
138 |
+
|
139 |
+
trainer.train(100000, batch_size=8, seed=10, lr=2e-4, min_steps=64, max_steps=96)
|
140 |
+
|
141 |
+
# to access train state:
|
142 |
+
|
143 |
+
state = trainer.state
|
144 |
+
|
145 |
+
# save
|
146 |
+
nca.save(state.params, "saved_params")
|
147 |
+
|
148 |
+
# load params
|
149 |
+
loaded_params = nca.load("saved_params")
|
150 |
+
|
151 |
+
```
|
images/gecko.gif
ADDED
images/model.svg
ADDED
jax_nca/__init__.py
ADDED
File without changes
|
jax_nca/dataset.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import matplotlib.pyplot as plt
|
2 |
+
import numpy as np
|
3 |
+
from einops import repeat
|
4 |
+
|
5 |
+
from jax_nca.utils import load_emoji
|
6 |
+
|
7 |
+
|
8 |
+
def to_alpha(x):
|
9 |
+
return np.clip(x[:, :, :, 3:4], 0.0, 1.0)
|
10 |
+
|
11 |
+
|
12 |
+
def rgb(x, rgb=False):
|
13 |
+
# assume rgb premultiplied by alpha
|
14 |
+
if rgb:
|
15 |
+
return np.clip(x[:, :, :, :3], 0.0, 1.0)
|
16 |
+
rgb, a = x[:, :, :, :3], to_alpha(x)
|
17 |
+
return np.clip(1.0 - a + rgb, 0.0, 1.0)
|
18 |
+
|
19 |
+
|
20 |
+
class ImageDataset:
|
21 |
+
def __init__(self, emoji: str = None, img: np.array = None, img_size: int = 64):
|
22 |
+
if img is None:
|
23 |
+
img = load_emoji(emoji, img_size)
|
24 |
+
self.rgb = img.shape[-1] == 3
|
25 |
+
self.img_shape = img.shape
|
26 |
+
self.img = np.expand_dims(img, 0) # (b w h c)
|
27 |
+
self.rgb_img = rgb(self.img, self.rgb)
|
28 |
+
|
29 |
+
def get_batch(self, batch_size: int = 1):
|
30 |
+
return repeat(
|
31 |
+
self.img, "b w h c -> (b repeat) w h c", repeat=batch_size
|
32 |
+
), repeat(self.rgb_img, "b w h c -> (b repeat) w h c", repeat=batch_size)
|
33 |
+
|
34 |
+
def visualize(self):
|
35 |
+
_ = plt.imshow(self.rgb_img[0])
|
36 |
+
plt.show()
|
jax_nca/nca.py
ADDED
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
from typing import Tuple
|
3 |
+
|
4 |
+
import flax.linen as nn
|
5 |
+
import jax
|
6 |
+
import jax.numpy as jnp
|
7 |
+
import numpy as np
|
8 |
+
from flax import serialization
|
9 |
+
from jax import lax
|
10 |
+
|
11 |
+
|
12 |
+
class SobelPerceptionNet(nn.Module):
|
13 |
+
@nn.compact
|
14 |
+
def __call__(self, x):
|
15 |
+
# x shape - BHWC
|
16 |
+
|
17 |
+
num_channels = x.shape[-1]
|
18 |
+
|
19 |
+
# 2D sobel kernels - IOHW layout
|
20 |
+
|
21 |
+
x_sobel_kernel = jnp.zeros(
|
22 |
+
(num_channels, num_channels, 3, 3), dtype=jnp.float32
|
23 |
+
)
|
24 |
+
x_sobel_kernel += (
|
25 |
+
jnp.array([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]])[
|
26 |
+
jnp.newaxis, jnp.newaxis, :, :
|
27 |
+
]
|
28 |
+
/ 8.0
|
29 |
+
)
|
30 |
+
|
31 |
+
y_sobel_kernel = jnp.zeros(
|
32 |
+
(num_channels, num_channels, 3, 3), dtype=jnp.float32
|
33 |
+
)
|
34 |
+
y_sobel_kernel += (
|
35 |
+
jnp.array([[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]])[
|
36 |
+
jnp.newaxis, jnp.newaxis, :, :
|
37 |
+
]
|
38 |
+
/ 8.0
|
39 |
+
)
|
40 |
+
x = jnp.transpose(x, [0, 3, 1, 2]) # N C H W
|
41 |
+
|
42 |
+
x_out = lax.conv(
|
43 |
+
x, # lhs = NCHW image tensor
|
44 |
+
x_sobel_kernel, # rhs = OIHW conv kernel tensor
|
45 |
+
(1, 1), # window strides
|
46 |
+
"SAME",
|
47 |
+
) # padding mode
|
48 |
+
|
49 |
+
y_out = lax.conv(
|
50 |
+
x, # lhs = NCHW image tensor
|
51 |
+
y_sobel_kernel, # rhs = OIHW conv kernel tensor
|
52 |
+
(1, 1), # window strides
|
53 |
+
"SAME",
|
54 |
+
) # padding mode
|
55 |
+
|
56 |
+
out = jnp.concatenate([x, x_out, y_out], axis=1)
|
57 |
+
return jnp.transpose(out, [0, 2, 3, 1]) # N H W C
|
58 |
+
|
59 |
+
|
60 |
+
class UpdateNet(nn.Module):
|
61 |
+
|
62 |
+
num_channels: int
|
63 |
+
|
64 |
+
@nn.compact
|
65 |
+
def __call__(self, x):
|
66 |
+
update_layer_1 = nn.Conv(
|
67 |
+
features=64, kernel_size=(1, 1), strides=1, padding="VALID"
|
68 |
+
)
|
69 |
+
update_layer_2 = nn.Conv(
|
70 |
+
features=64, kernel_size=(1, 1), strides=1, padding="VALID"
|
71 |
+
)
|
72 |
+
update_layer_3 = nn.Conv(
|
73 |
+
features=self.num_channels,
|
74 |
+
kernel_size=(1, 1),
|
75 |
+
strides=1,
|
76 |
+
padding="VALID",
|
77 |
+
kernel_init=jax.nn.initializers.zeros,
|
78 |
+
use_bias=False,
|
79 |
+
)
|
80 |
+
x = update_layer_1(x)
|
81 |
+
x = nn.relu(x)
|
82 |
+
x = update_layer_2(x)
|
83 |
+
x = nn.relu(x)
|
84 |
+
x = update_layer_3(x)
|
85 |
+
return x
|
86 |
+
|
87 |
+
|
88 |
+
class TrainablePerception(nn.Module):
|
89 |
+
num_channels: int
|
90 |
+
|
91 |
+
@nn.compact
|
92 |
+
def __call__(self, x):
|
93 |
+
out = nn.Conv(
|
94 |
+
features=self.num_channels * 3,
|
95 |
+
kernel_size=(3, 3),
|
96 |
+
use_bias=False,
|
97 |
+
feature_group_count=self.num_channels,
|
98 |
+
)(x)
|
99 |
+
return out
|
100 |
+
|
101 |
+
|
102 |
+
@functools.partial(jax.jit, static_argnames=("apply_fn", "num_steps"))
|
103 |
+
def nca_multi_step(apply_fn, params, current_state: jnp.array, rng, num_steps: int):
|
104 |
+
def forward(carry, inp):
|
105 |
+
carry = apply_fn({"params": params}, carry, rng)
|
106 |
+
return carry, carry
|
107 |
+
|
108 |
+
x, outs = jax.lax.scan(forward, current_state, None, length=num_steps)
|
109 |
+
return x, outs
|
110 |
+
|
111 |
+
|
112 |
+
class NCA(nn.Module):
|
113 |
+
num_hidden_channels: int
|
114 |
+
num_target_channels: int = 3
|
115 |
+
alpha_living_threshold: float = 0.1
|
116 |
+
cell_fire_rate: float = 1.0
|
117 |
+
trainable_perception: bool = False
|
118 |
+
alpha: float = 1.0
|
119 |
+
|
120 |
+
"""
|
121 |
+
num_hidden_channels: Number of hidden channels for each cell to use
|
122 |
+
num_target_channels: Number of target channels to be used
|
123 |
+
alpha_living_threshold: threshold to determine whether a cell lives or dies
|
124 |
+
cell_fire_rate: probability that a cell receives an update per step
|
125 |
+
trainable_perception: if true, instead of using sobel filters use a trainable conv net
|
126 |
+
alpha: scalar value to be multiplied to updates
|
127 |
+
"""
|
128 |
+
|
129 |
+
@classmethod
|
130 |
+
def create_seed(
|
131 |
+
cls,
|
132 |
+
num_hidden_channels: int,
|
133 |
+
num_target_channels: int = 3,
|
134 |
+
shape: Tuple[int] = (48, 48),
|
135 |
+
batch_size: int = 1,
|
136 |
+
):
|
137 |
+
seed = np.zeros((batch_size, *shape, num_hidden_channels + 3 + 1))
|
138 |
+
w, h = seed.shape[1], seed.shape[2]
|
139 |
+
seed[:, w // 2, h // 2, 3:] = 1.0
|
140 |
+
return seed
|
141 |
+
|
142 |
+
def setup(self):
|
143 |
+
num_channels = 3 + self.num_hidden_channels + 1
|
144 |
+
if self.trainable_perception:
|
145 |
+
self.perception = TrainablePerception(num_channels)
|
146 |
+
else:
|
147 |
+
self.perception = SobelPerceptionNet()
|
148 |
+
self.update_net = UpdateNet(num_channels)
|
149 |
+
|
150 |
+
def alive(self, x, alpha_living_threshold: float):
|
151 |
+
return (
|
152 |
+
nn.max_pool(
|
153 |
+
x[..., 3:4], window_shape=(3, 3), strides=(1, 1), padding="SAME"
|
154 |
+
)
|
155 |
+
> alpha_living_threshold
|
156 |
+
)
|
157 |
+
|
158 |
+
def get_stochastic_update_mask(self, x, rng, cell_fire_rate: float = 1.0):
|
159 |
+
return jnp.array(np.random.uniform(size=x[..., :1].shape) <= cell_fire_rate)
|
160 |
+
|
161 |
+
def __call__(self, x, rng):
|
162 |
+
pre_life_mask = self.alive(x, self.alpha_living_threshold)
|
163 |
+
|
164 |
+
perception_out = self.perception(x)
|
165 |
+
update = self.alpha * jnp.reshape(self.update_net(perception_out), x.shape)
|
166 |
+
|
167 |
+
if self.cell_fire_rate >= 1.0:
|
168 |
+
stochastic_update_mask = self.get_stochastic_update_mask(
|
169 |
+
x, rng, self.cell_fire_rate
|
170 |
+
).astype(float)
|
171 |
+
x = x + update * stochastic_update_mask
|
172 |
+
else:
|
173 |
+
x = x + update
|
174 |
+
|
175 |
+
post_life_mask = self.alive(x, self.alpha_living_threshold)
|
176 |
+
|
177 |
+
life_mask = pre_life_mask & post_life_mask
|
178 |
+
life_mask = life_mask.astype(float)
|
179 |
+
|
180 |
+
return x * life_mask
|
181 |
+
|
182 |
+
def save(self, params, path: str):
|
183 |
+
bytes_output = serialization.to_bytes(params)
|
184 |
+
with open(path, "wb") as f:
|
185 |
+
f.write(bytes_output)
|
186 |
+
|
187 |
+
def load(self, path: str):
|
188 |
+
nca_seed = self.create_seed(
|
189 |
+
self.num_hidden_channels, self.num_target_channels, batch_size=1
|
190 |
+
)
|
191 |
+
rng = jax.random.PRNGKey(0)
|
192 |
+
init_params = self.init(rng, nca_seed, rng)["params"]
|
193 |
+
with open(path, "rb") as f:
|
194 |
+
bytes_output = f.read()
|
195 |
+
return serialization.from_bytes(init_params, bytes_output)
|
196 |
+
|
197 |
+
def multi_step(self, params, current_state: jnp.array, rng, num_steps: int = 2):
|
198 |
+
return nca_multi_step(self.apply, params, current_state, rng, num_steps)
|
199 |
+
|
200 |
+
def to_rgb(self, x: jnp.array):
|
201 |
+
rgb, a = x[..., :3], jnp.clip(x[..., 3:4], 0.0, 1.0)
|
202 |
+
rgb = jnp.clip(1.0 - a + rgb, 0.0, 1.0)
|
203 |
+
return rgb
|
jax_nca/trainer.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import os
|
3 |
+
from collections.abc import Iterable
|
4 |
+
from datetime import datetime
|
5 |
+
|
6 |
+
import jax
|
7 |
+
import jax.numpy as jnp
|
8 |
+
import numpy as np
|
9 |
+
import optax
|
10 |
+
import pandas as pd
|
11 |
+
import tqdm
|
12 |
+
from flax.training import train_state # Useful dataclass to keep train state
|
13 |
+
from tensorboardX import SummaryWriter
|
14 |
+
|
15 |
+
from jax_nca.utils import make_circle_masks
|
16 |
+
|
17 |
+
|
18 |
+
def get_tensorboard_logger(
|
19 |
+
experiment_name: str, base_log_path: str = "tensorboard_logs"
|
20 |
+
):
|
21 |
+
log_path = "{}/{}_{}".format(base_log_path, experiment_name, datetime.now())
|
22 |
+
train_writer = SummaryWriter(log_path, flush_secs=10)
|
23 |
+
full_log_path = os.path.join(os.getcwd(), log_path)
|
24 |
+
print(
|
25 |
+
"Follow tensorboard logs with: python -m tensorboard.main --logdir '{}'".format(
|
26 |
+
full_log_path
|
27 |
+
)
|
28 |
+
)
|
29 |
+
return train_writer
|
30 |
+
|
31 |
+
|
32 |
+
def create_train_state(rng, nca, learning_rate, shape):
|
33 |
+
nca_seed = nca.create_seed(
|
34 |
+
nca.num_hidden_channels, nca.num_target_channels, shape=shape[:-1], batch_size=1
|
35 |
+
)
|
36 |
+
"""Creates initial `TrainState`."""
|
37 |
+
params = nca.init(rng, nca_seed, rng)["params"]
|
38 |
+
tx = optax.chain(
|
39 |
+
# optax.clip_by_global_norm(10.0),
|
40 |
+
optax.adam(learning_rate),
|
41 |
+
)
|
42 |
+
return train_state.TrainState.create(apply_fn=nca.apply, params=params, tx=tx)
|
43 |
+
|
44 |
+
|
45 |
+
def clip_grad_norm(grad):
|
46 |
+
factor = 1.0 / (
|
47 |
+
jnp.linalg.norm(jax.tree_util.tree_leaves(jax.tree_map(jnp.linalg.norm, grad)))
|
48 |
+
+ 1e-8
|
49 |
+
)
|
50 |
+
return jax.tree_map((lambda x: x * factor), grad)
|
51 |
+
|
52 |
+
|
53 |
+
@functools.partial(jax.jit, static_argnames=("apply_fn", "num_steps"))
|
54 |
+
def train_step(
|
55 |
+
apply_fn, state, seeds: jnp.array, targets: jnp.array, num_steps: int, rng
|
56 |
+
):
|
57 |
+
def mse_loss(pred, y):
|
58 |
+
squared_diff = jnp.square(pred - y)
|
59 |
+
return jnp.mean(squared_diff, axis=[-3, -2, -1])
|
60 |
+
|
61 |
+
def loss_fn(params):
|
62 |
+
def forward(carry, inp):
|
63 |
+
carry = apply_fn({"params": params}, carry, rng)
|
64 |
+
return carry, carry
|
65 |
+
|
66 |
+
x, outs = jax.lax.scan(forward, seeds, None, length=num_steps)
|
67 |
+
rgb, a = x[..., :3], jnp.clip(x[..., 3:4], 0.0, 1.0)
|
68 |
+
rgb = jnp.clip(1.0 - a + rgb, 0.0, 1.0)
|
69 |
+
|
70 |
+
outs = jnp.transpose(outs, [1, 0, 2, 3, 4])
|
71 |
+
subset = outs[:, -8:] # B 12 H W C
|
72 |
+
return jnp.mean(
|
73 |
+
jax.vmap(mse_loss)(subset[..., :4], jnp.expand_dims(targets, 1))
|
74 |
+
), (x, rgb)
|
75 |
+
|
76 |
+
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
|
77 |
+
(loss, aux), grads = grad_fn(state.params)
|
78 |
+
grads = clip_grad_norm(grads)
|
79 |
+
updated, rgb = aux
|
80 |
+
return state.apply_gradients(grads=grads), loss, grads, updated, rgb
|
81 |
+
|
82 |
+
|
83 |
+
class SamplePool:
|
84 |
+
def __init__(self, max_size: int = 1000):
|
85 |
+
self.max_size = max_size
|
86 |
+
self.pool = [None] * max_size
|
87 |
+
|
88 |
+
def __getitem__(self, idx):
|
89 |
+
if isinstance(idx, Iterable):
|
90 |
+
return [self.pool[i] for i in idx]
|
91 |
+
return idx
|
92 |
+
|
93 |
+
def __setitem__(self, idx, v):
|
94 |
+
if isinstance(idx, Iterable):
|
95 |
+
for i in range(len(idx)):
|
96 |
+
index = idx[i]
|
97 |
+
self.pool[index] = v[i]
|
98 |
+
else:
|
99 |
+
self.pool[idx] = v
|
100 |
+
|
101 |
+
def sample(self, num_samples: int):
|
102 |
+
indices = np.random.randint(0, self.max_size, num_samples)
|
103 |
+
return self.__getitem__(indices), indices
|
104 |
+
|
105 |
+
|
106 |
+
def flatten(d):
|
107 |
+
df = pd.json_normalize(d, sep="_")
|
108 |
+
return df.to_dict(orient="records")[0]
|
109 |
+
|
110 |
+
|
111 |
+
class EmojiTrainer:
|
112 |
+
def __init__(self, dataset, nca, pool_size: int = 1024, n_damage: int = 0):
|
113 |
+
self.dataset = dataset
|
114 |
+
self.img_shape = self.dataset.img_shape
|
115 |
+
self.nca = nca
|
116 |
+
self.pool_size = pool_size
|
117 |
+
self.n_damage = n_damage
|
118 |
+
self.state = None
|
119 |
+
|
120 |
+
def train(
|
121 |
+
self,
|
122 |
+
num_epochs,
|
123 |
+
batch_size: int = 8,
|
124 |
+
seed: int = 10,
|
125 |
+
lr: float = 0.001,
|
126 |
+
min_steps: int = 64,
|
127 |
+
max_steps: int = 96,
|
128 |
+
):
|
129 |
+
pool = SamplePool(self.pool_size)
|
130 |
+
|
131 |
+
writer = get_tensorboard_logger("EMOJITrainer")
|
132 |
+
rng = jax.random.PRNGKey(seed)
|
133 |
+
rng, init_rng = jax.random.split(rng)
|
134 |
+
self.state = create_train_state(init_rng, self.nca, lr, self.dataset.img_shape)
|
135 |
+
|
136 |
+
bar = tqdm.tqdm(np.arange(num_epochs))
|
137 |
+
try:
|
138 |
+
for i in bar:
|
139 |
+
num_steps = int(np.random.randint(min_steps, max_steps))
|
140 |
+
samples, indices = pool.sample(batch_size)
|
141 |
+
for j in range(len(samples)):
|
142 |
+
if samples[j] is None:
|
143 |
+
samples[j] = self.nca.create_seed(
|
144 |
+
self.nca.num_hidden_channels,
|
145 |
+
self.nca.num_target_channels,
|
146 |
+
shape=self.img_shape[:-1],
|
147 |
+
batch_size=1,
|
148 |
+
)[0]
|
149 |
+
samples[0] = self.nca.create_seed(
|
150 |
+
self.nca.num_hidden_channels,
|
151 |
+
self.nca.num_target_channels,
|
152 |
+
shape=self.img_shape[:-1],
|
153 |
+
batch_size=1,
|
154 |
+
)[0]
|
155 |
+
batch = np.stack(samples)
|
156 |
+
if self.n_damage > 0:
|
157 |
+
damage = (
|
158 |
+
1.0
|
159 |
+
- make_circle_masks(
|
160 |
+
int(self.n_damage), self.img_shape[0], self.img_shape[1]
|
161 |
+
)[..., None]
|
162 |
+
)
|
163 |
+
batch[-self.n_damage :] *= damage
|
164 |
+
|
165 |
+
batch = jnp.array(batch)
|
166 |
+
targets, rgb_targets = self.dataset.get_batch(batch_size)
|
167 |
+
targets = jnp.array(targets)
|
168 |
+
|
169 |
+
self.state, loss, grads, outputs, rgb_outputs = train_step(
|
170 |
+
self.nca.apply,
|
171 |
+
self.state,
|
172 |
+
batch,
|
173 |
+
targets,
|
174 |
+
num_steps=num_steps,
|
175 |
+
rng=rng,
|
176 |
+
)
|
177 |
+
|
178 |
+
grad_dict = {k: dict(grads[k]) for k in grads.keys()}
|
179 |
+
grad_dict = flatten(grad_dict)
|
180 |
+
|
181 |
+
grad_dict = {
|
182 |
+
k: {kk: np.sum(vv).item() for kk, vv in v.items()}
|
183 |
+
for k, v in grad_dict.items()
|
184 |
+
}
|
185 |
+
grad_dict = flatten(grad_dict)
|
186 |
+
|
187 |
+
pool[indices] = np.array(outputs)
|
188 |
+
|
189 |
+
bar.set_description("Loss: {}".format(loss.item()))
|
190 |
+
|
191 |
+
self.emit_metrics(
|
192 |
+
writer,
|
193 |
+
i,
|
194 |
+
batch,
|
195 |
+
rgb_outputs,
|
196 |
+
rgb_targets,
|
197 |
+
loss.item(),
|
198 |
+
metrics=grad_dict,
|
199 |
+
)
|
200 |
+
|
201 |
+
return self.state
|
202 |
+
except Exception:
|
203 |
+
return self.state
|
204 |
+
|
205 |
+
def emit_metrics(
|
206 |
+
self, train_writer, i: int, batch, outputs, targets, loss, metrics={}
|
207 |
+
):
|
208 |
+
train_writer.add_scalar("loss", loss, i)
|
209 |
+
# train_writer.add_scalar("log10(loss)", math.log10(loss), i)
|
210 |
+
train_writer.add_images("batch", self.nca.to_rgb(batch), i, dataformats="NHWC")
|
211 |
+
train_writer.add_images("outputs", outputs, i, dataformats="NHWC")
|
212 |
+
train_writer.add_images("targets", targets, i, dataformats="NHWC")
|
213 |
+
for k in metrics:
|
214 |
+
train_writer.add_scalar(k, metrics[k], i)
|
jax_nca/utils.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import PIL.Image
|
5 |
+
import requests
|
6 |
+
|
7 |
+
|
8 |
+
def load_image(url, size):
|
9 |
+
r = requests.get(url)
|
10 |
+
img = PIL.Image.open(io.BytesIO(r.content))
|
11 |
+
img.thumbnail((40, 40), PIL.Image.ANTIALIAS)
|
12 |
+
img = np.float32(img) / 255.0
|
13 |
+
# premultiply RGB by Alpha
|
14 |
+
img[..., :3] *= img[..., 3:]
|
15 |
+
# pad to self.h, self.h
|
16 |
+
diff = size - 40
|
17 |
+
img = np.pad(img, ((diff // 2, diff // 2), (diff // 2, diff // 2), (0, 0)))
|
18 |
+
return img
|
19 |
+
|
20 |
+
|
21 |
+
def load_emoji(emoji, size, code=None):
|
22 |
+
if code is None:
|
23 |
+
code = hex(ord(emoji))[2:].lower()
|
24 |
+
url = (
|
25 |
+
"https://github.com/googlefonts/noto-emoji/blob/main/png/128/emoji_u%s.png?raw=true"
|
26 |
+
% code
|
27 |
+
)
|
28 |
+
return load_image(url, size)
|
29 |
+
|
30 |
+
|
31 |
+
def make_circle_masks(n, h, w, r=None):
|
32 |
+
x = np.linspace(-1.0, 1.0, w)[None, None, :]
|
33 |
+
y = np.linspace(-1.0, 1.0, h)[None, :, None]
|
34 |
+
center = np.random.uniform(-0.5, 0.5, size=[2, n, 1, 1])
|
35 |
+
if r is None:
|
36 |
+
r = np.random.uniform(0.1, 0.4, size=[n, 1, 1])
|
37 |
+
x, y = (x - center[0]) / r, (y - center[1]) / r
|
38 |
+
mask = x * x + y * y < 1.0
|
39 |
+
return mask.astype(float)
|
notebooks/Gecko.ipynb
ADDED
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"source": [
|
7 |
+
"import abc\n",
|
8 |
+
"import random\n",
|
9 |
+
"import numpy as np\n",
|
10 |
+
"import matplotlib.pyplot as plt\n",
|
11 |
+
"import tensorflow as tf\n",
|
12 |
+
"from einops import repeat, rearrange\n",
|
13 |
+
"tf.config.experimental.set_visible_devices([], 'GPU')\n",
|
14 |
+
"\n",
|
15 |
+
"# uncomment this to enable jax gpu preallocation, might lead to memory issues\n",
|
16 |
+
"\n",
|
17 |
+
"import os\n",
|
18 |
+
"os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\""
|
19 |
+
],
|
20 |
+
"outputs": [],
|
21 |
+
"metadata": {}
|
22 |
+
},
|
23 |
+
{
|
24 |
+
"cell_type": "markdown",
|
25 |
+
"source": [
|
26 |
+
"# Gecko"
|
27 |
+
],
|
28 |
+
"metadata": {}
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"cell_type": "code",
|
32 |
+
"execution_count": 2,
|
33 |
+
"source": [
|
34 |
+
"from jax_nca.dataset import ImageDataset"
|
35 |
+
],
|
36 |
+
"outputs": [],
|
37 |
+
"metadata": {}
|
38 |
+
},
|
39 |
+
{
|
40 |
+
"cell_type": "code",
|
41 |
+
"execution_count": 3,
|
42 |
+
"source": [
|
43 |
+
"dataset = ImageDataset(emoji='🦎', img_size=64)"
|
44 |
+
],
|
45 |
+
"outputs": [],
|
46 |
+
"metadata": {}
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"cell_type": "code",
|
50 |
+
"execution_count": 4,
|
51 |
+
"source": [
|
52 |
+
"dataset.visualize()"
|
53 |
+
],
|
54 |
+
"outputs": [
|
55 |
+
{
|
56 |
+
"output_type": "display_data",
|
57 |
+
"data": {
|
58 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD7CAYAAACscuKmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAcMElEQVR4nO3de3RV1Z0H8O8vT8IzQQIEgjwERQYFnCugRUdQLFVb1DrWR0dWS8u00zq22qq0s2a0ra22tdVZa8aWVa1MfeCjtVp0oYhanwWDKAKBgrwDgfAIEB4Jyf3NH/dw9t7H3OSS3HsT2N/PWix+9+5z79nJze+evc/ZZ29RVRDRyS+noytARNnBZCfyBJOdyBNMdiJPMNmJPMFkJ/JEu5JdRKaJyBoRWScid6arUkSUftLW6+wikgvg7wCmAtgK4H0A16vqqvRVj4jSJa8drx0PYJ2qrgcAEZkHYDqApMnep08fHTJkSDt2SUQt2bhxI3bt2iXNlbUn2QcC2GI93gpgQksvGDJkCCoqKtqxSyJqSSwWS1qW8RN0IjJLRCpEpKKmpibTuyOiJNqT7FUABlmPy4PnHKo6R1VjqhorLS1tx+6IqD3ak+zvAxghIkNFpADAdQBeSE+1iCjd2txnV9VGEfk2gJcB5AJ4RFVXpq1mRJRW7TlBB1V9CcBLaaoLEWUQR9AReYLJTuQJJjuRJ5jsRJ5gshN5gslO5AkmO5EnmOxEnmCyE3mCyU7kCSY7kSeY7ESeYLITeYLJTuQJJjuRJ5jsRJ5gshN5gslO5AkmO5EnmOxEnmCyE3mCyU7kCSY7kSeY7ESeYLITeaLVZBeRR0Rkp4issJ7rLSILRWRt8H9JZqtJRO2VypH9UQDTIs/dCWCRqo4AsCh4TESdWKvJrqpvAtgTeXo6gLlBPBfAlemtFhGlW1v77P1UdXsQVwPol6b6EFGGtPsEnaoqAE1WLiKzRKRCRCpqamrauzsiaqO2JvsOESkDgOD/nck2VNU5qhpT1VhpaWkbd0dE7dXWZH8BwIwgngHg+fRUh4gyJZVLb08CeA/AGSKyVURmArgXwFQRWQvgkuAxEXViea1toKrXJym6OM11IaIMajXZKXMS5zYNETnu16X6GiIOlyXyBJOdyBNsxmeZInkTvLbhQBjfv+65MB7Spa+z3cxhZvRyPB53yuz3ZBOfbDyyE3mCyU7kCSY7kSfYZ88wjdw2ELcum8W1ySmb/s5dYfy2rjMFje57LK1dG8b/e87NTlmTmj58LthnJ4NHdiJPMNmJPMFmfIbFI6PkcsV8v95T+YRT9vZR0zwv2G+a43Fx3+O3+koYD17lXpa7Y9SXwthp0gu/133HvwAiTzDZiTzBZnwGxJG8+Vy5b1MY/2zNU05ZTtycnW+w3gPuIDnk1plm/Q82/sEpO6tkWBhfVnZuGNtN+ubqZbNH5UmOOaMf7ZLk2KP1eOa/0+ORncgTTHYiTzDZiTzBPnsGOF3bSFf2/rV/CuOGXu53rd0XR4PVx468R9x6mextcMq+sezBMF7e5zdh3Cuvm1tH++67yA5ycpo/BuTyLroTGo/sRJ5gshN5gs34NIje7GJf1qprPOSULaipCOOc1XvdN2qy3meYaXZrQa77/m/uCOOuc9Y4ZTtvGB7Gj4x+NYxvPfVKZ7uj8UZTj8hluE0Hq83753UJ41X7Nznbjeg+MIzLuyZfE4CX5ToHHtmJPMFkJ/IEk53IE+yzp4HGI/O/W0NMNx91V7uuXr81jItuq3DKUGT65ocfMENddUh3Z7Oc3fVmXwcOu/veaR7P37EkjL87aLqzXX5O8o8+9qqZEKOkwOx7Y3y3s92t1nv+YszXnbJG65xAXgv7ouxJZfmnQSLyuoisEpGVInJL8HxvEVkoImuD/0syX10iaqtUmvGNAG5T1VEAJgL4loiMAnAngEWqOgLAouAxEXVSqaz1th3A9iA+ICKVAAYCmA7gomCzuQDeAHBHRmrZyWkLV5YOHKlzn+hZEIbx4T3csi6mGa9drY+m7qizWcNnB5j3GNrTKYuPMO+5YbfpMkSXmtrVsC+MH9v8mlN2dvHQMF5Xty2MRxUMdLb7/aaXw7hXgTtC7z/OvCGMOYlG53Bcv3kRGQJgHIDFAPoFXwQAUA2gX3qrRkTplHKyi0h3AH8E8B1V3W+XaeKwoUleN0tEKkSkoqampl2VJaK2SynZRSQfiUR/XFWP3cmxQ0TKgvIyADube62qzlHVmKrGSkuTj7Iiosxqtc8uiQXDHgZQqaq/sopeADADwL3B/89npIYnAIm2aaw+fN+8Xk5Rbs/CMD5yz1j3ddb7aL75Hs6PzBufn2fK6se6F0Hsy4CHG8wluuidbO/trgzj7+12J74cfKhrGO8+Yhpx1Udrne3O7DkojMf2GuaUuWvagTqBVC6AfgbAvwD4WEQ+DJ77ARJJ/rSIzASwCcC1GakhEaVFKmfj38an7qgOXZze6hBRpnBoUxpEl0a2L3OVd3fPUwwv7B/Gq3O2O2WoN5eoiqw74Eq6uXe9dc81+6s/7E4kudmK8/PMx9tgjWgDgItLx4bx443fdsp+uv9J87oiU4+BuX2c7V6bdG8Ylxa63Ql70spkk2G0xn4P+/JmtNtkl/HSXnL8zRB5gslO5Ak249Mg2oxvbDJN5vxc91f81WHTwviO9XPd97Ga5DmF5nt4/CF3tdch1u7WlBQ6ZQ0NZtsSMaPpCiI3o9jN3etOvcgpi/U6LYxHPHpJGA8eOsDZzm66f2pe+jY03aOTgCRt/vPsfpvwyE7kCSY7kSeY7ESeYJ89A3JzzaUy+/IRAHxz2BVh/OiGV5yyyqKqMD5idUyv2HjA2e6CI+acwENTy5wysfq9q4/UhvGGg+5lvqHd3NfZPqj6yLzfQTMZRuUnS53t9jSY0XW9C9y771qal97W0iW6dXXm97G8dkMYnxap+6G4GSk48ZQzk+7L94kveWQn8gSTncgTbMZngN1cjF5Osudhf3LCbKfswre/H8b71TSfHzrdbSJX9jIf26hC9/u6MC8/jPfkHwnjS968zdnum8OuDuMNG5c7ZXOWmEuC2sX8LAVwR/LlWKP8opNjpNpitpvuR5vcUX73VJqRfP+3640w7llf4Gxn/753feGZpO+fatfiZMUjO5EnmOxEnmCyE3mCffYMi15OaoibJZZHWxM7AsCi881dZJ9fcncYV6m7JtxWq398eo7b95y2zkxwWWpNZDGvr9unvmuD6ZcfPOwu+5xTbi5txfLNENk/XvaAs11xkZmYozHS386zhgk3NJr3j64r997uVWF8zcf3OWVXFY8P4/EFZghvY747fLj2qPmZJ7/5fafsxUk/DuNueUVhHD2X4kMfnkd2Ik8w2Yk8wWZ8lhXkFCQtG97DNJnfsCaGuOSt7znbVTUeDOMdR9wm7ZXLasO4cKiZy339ae5HXVhv5qJfIflO2b4SM0lFVRfT9J1T/bqz3e3dvhjGPQvdJarsu+AK8pL/zAO7m8k8+nU9xSn7aJ8ZNVd1eFcY16vbZbDvsNvd4I42jNyM5zUe2Yk8wWQn8gSb8RkWndThR6seC+OjkXnhHtlollPadrkZPXZF/0nOdn+pNTfQbDjinlV+6YtmiaaiAvNdHotMR92zqykrzHXLVh0w9dpeb1ahvWfrs249qv8Wxo+OvdUpG9d7RBjPX7sojF/b/K6zXf8upul+a9cLnLL795nfR1VTbRh3EffP9qHRZg69mwZPBTWPR3YiTzDZiTzBZCfyBPvsaRAdjWWLztf+cnVFGO+qd9bHRK88c6msrtHc9XbT4CnOdg9vXRDGy/Pc9++fby5zndndfLynr9rnbHf5e+ZS1iNXlztlxVZfv36HmRhidYO7r+XdtoTxxe/8wCl7aPS/hvGsZ02fen9RPVzWyLVGd2lqOcWcf5DhJj6vabiznd1Pb2ElLu+1emQXkS4iskREPhKRlSJyd/D8UBFZLCLrROQpEUl+MZWIOlwqzfh6AFNUdQyAsQCmichEAPcB+LWqDgewF8DMjNWSiNotlbXeFMCxOw3yg38KYAqAG4Ln5wK4C8BD6a9i52evnAq4N79c9c5dTlnPQtNUP2A11QFgnZr169/dY1ZZndY/5mx3Vr5ZPXUttjhlfazLZt27mMkmSvu488uvnWAueQ3s55YVWstQffkvZu6690rckXa3jjfvsavuoFN2w/u/COPcQlOP3Dr3UmS/7v3CuLx3f6esoJ8ZUdjTWjbr9qFXOdvFYS0TFWnHczkoI9X12XODFVx3AlgI4BMAtarhuMWtAAYmeTkRdQIpJbuqNqnqWADlAMYDGJnqDkRklohUiEhFTU1N6y8goow4rjaOqtYCeB3AeQCKRcKhTOUAqpK8Zo6qxlQ1Vlpa2twmRJQFrfbZRaQUwFFVrRWRIgBTkTg59zqAawDMAzADwPOZrOiJ6nNl5zqP5255NYxX17n9bRSbvvOyfevDOC/SH16x5s0wrh/oXjZblm++v/N3mQkn0bfI2W5Eb9P/HvXmLqds05jiMF45ydwBFy9wjw2f72n+fJZFDhsf1ZsnastM31urNjjbPXy5mbBi2ojJaDdea0sqlevsZQDmikguEi2Bp1V1voisAjBPRH4CYBmAhzNYTyJqp1TOxi8HMK6Z59cj0X8nohMAR9ClgbbQdLywz1nO41srfxfGk/uc7ZStOrg1jH9e+UQYN65e42x3uNFM0CBb3UtZOwcPCeP3rSWgD+9155nru9OMZBv5gTvH3baRZqnnquFmUoq+mw45251jXdqLbXHLlteZS4C/79c1jHfVu8s+jzzFzC0XnXu+KW4m5hDrcqZELq8lXdqZHPwtEXmCyU7kCTbj0yCnhVPA0RVHn/lHs+TT1QPdSSmerno7jK97644w1gb3hhl0M7chPDDhu05R/0FnmPdYauaxO1Tn3sRypJs5G7/iKnc8VFmxKRuz0VwJOPelbc52L88cFsaj9rk3sQytNd2G5UNNV+CD/sXOdq/UrgjjWb1PdcrsprszEo5n3NuER3YiTzDZiTzBZCfyBPvsaSCSvBPZs6Cb89jup8fj7mWza62yPxeb+OUSdwTdjyeayR3/7dybku67e44Zkfe15f/tlP0t11y++6TRrcfIbeZuvO3WBBjV17p96i555lixItbbKWuyfiVnW29fF3cvAf680kzA+aVB/+SU9cw3l+x8X245HXhkJ/IEk53IE2zGZ1h0VFjcao7m5CRvjv7h0p+G8e4L3JVJ+3azbk6JdAXs0XyXDZgQxkuLH3S2++1mMyf7w+sXOGXvHjYj6jYfNaPYNvfIdbabWG/Kyru6f0o9rUtvVy+vDeP603o621X3MCPvntn0V6fsa8MvC+NGay6/vBz+2bYFj+xEnmCyE3mCyU7kCXZ+Mix6WS43xctGuTmmf2z30QH3bjB7uyi7n1vW1X2P6d1Hh/G8rY84ZVsbd4fxJ73NpJIH1J0A44jVZ5+U5x43xlqTVg6sNpNonDbC7bOX55vfxwvb3nDK7D57Do9L7cbfIJEnmOxEnmAzvpNyLtlFWv4tNd2TLUVV1+COwrvoiRvCeH+OWyYHTfO/qNaUHRrgznf3bqkZHZi3113WqUeJGb23//rBYdwUGa1Xtsdcolt6YK1Ttu2QmRtvgNUNif6MHFGXGh7ZiTzBZCfyBJvxnVRLN9e0+DqrSWuPNOte4H7Uv5hiJsf491fvdsrq88xZ9nNGm5tTfnf+bGe712vNElV3rf2NUzbgoHmPyVare9gOd8mrtdbIu7wit4n/Qe06835WMz4eGZWY28bflW94ZCfyBJOdyBNMdiJPsM9+AmppIoea+towvuNjs0hPrPh0Z7vivmbZ5x4DhjllDYf3hPG7cTPJ5NdXuf3yNy80yzJX7neXsnrryCthPK7WTEZ59sp97nbnmr54YaH7s6w5sCmMr8DEMFZ1+/bgsswpSfm3FCzbvExE5gePh4rIYhFZJyJPiUhBa+9BRB3neL4SbwFQaT2+D8CvVXU4gL0AZqazYkSUXik140WkHMDlAO4BcKskrgtNAXBsGNZcAHcBeCgDdaSIlprx260m+NzNC8P4rzuXO9ttPLTDPLBGuwFATo/+Zl9dzPHgnZqPne02HTTvce8Y97t+9CtvhPGSXmaUXO3kfs52DfXmZ8lrcpvnGw5uR7OaHyRIrUj1yP4AgNsBHPs0TgFQq6rHxlVuBTCwmdcRUSfRarKLyBUAdqrq0rbsQERmiUiFiFTU1NS05S2IKA1SObJ/BsAXRGQjgHlINN8fBFAsIse6AeUAqpp7sarOUdWYqsZKS0vTUGUiaotU1mefDWA2AIjIRQC+p6o3isgzAK5B4gtgBoDnM1dNstkTOUQnnDy72FxG+6+RXw7j9/a7yz7bff1t9budsnpr0ovvll8ZxjMHXepsN6jIfHlHl02e3n9yGC85ZCa3zDvo1teeRb4JruojtWiORPZln8OIDqW1l3f2fWnn9vz0dyBxsm4dEn34h1vZnog60HENqlHVNwC8EcTrAYxPf5WIKBM4gu4EF23S2l7ctjiMNxx1T45+c7CZ3+3pLe587fayS78c/fWk79/YZJr7OepeAvzKkM+G8bPvzg/jvHx3u3yrO3Eozy3bD/cOuWMkcunNvkPwU3fAWQ/tLo+PTXr/fmIiTzHZiTzBZvwJL/nyUk+f9x9hvHSPO7/bP3/0szCe2nOMU7bqwOYwXnvAXFE9rVuZs12OdQNK9Cz4WSVDzeu6mptw1h5161HUYF63NzKCrmvczH8XR/Im+J6G/WH84f4NTtnAQrO67Bk9BsFnPLITeYLJTuQJJjuRJ9hnP8F9es5083hwN3OHWb64c81/dcDUMD6/5Eyn7MXqJWFcXGDmho/2le2Ra/aSVIA7yu/GwWZf31jj9tnzjpjXHc11+/1ab+56O9xo5qXff/SQs13stZvDuLrQnQM/77B5z9lDrwnju0bf5Gznw2W5k/OnIqJPYbITeUJUszcTQCwW04qKiqztz3fZbJq2tCRTbUNdGI9c4E5yUdNkLpvlxt0uSZM1J93Cc+8J4wm9z3C2W2rNL790z9+dstsrzQq1cWv03nvn/dLZbnzvkWY7uJcAT6QVZGOxGCoqKpqdSP/E+SmIqF2Y7ESeYLITeYKX3k5idj892qe2h7dGbxSzT+PkWH3vltafi14CbLLmdi8u6B7G3zn9ame7H256LIzz6pwiNFm3t92/9tkwfnHST5ztLiw9q9kYAO5d81QY7847EsZL97qXAJ0+e2TYbk7uyXFMPDl+CiJqFZOdyBNsxnsi2sxucZnjNKyAbDf/7UuAt4yY7mz3xKbXwnhlkTtnaYE1UG5B07IwvnHxvc52Vw48P4z/su1vTtnuxgNhrLlmFGHfwuKkdW9pQpAT2cn5UxHRpzDZiTzBZjxlhH3m3h6lWZTbxdnumfPNBBuXvDPbKdvW3az4mlNnbpiZt+MtZzvncTwyIrTQarrHzVWBKX3HOZvZdcxpqYtzAuORncgTTHYiTzDZiTzBPjtlnD2SL3pHmT0J5Gvnu5fUbv7YrAC+sOFDUxCZe955ywL3+FUmvcL4sXHfD+MSa1QfELlDUE7OY2Cq67NvBHAAieW4GlU1JiK9ATwFYAiAjQCuVdW9makmEbXX8XyFTVbVsaoaCx7fCWCRqo4AsCh4TESdVHua8dMBXBTEc5FYA+6OdtaHTnLRiSDs5vOInuVO2YLPmAkr3qr5OIz/asUAUHPUXKL7h15DnbKr+k0I49Kikmb3C5y8887ZUv0JFcArIrJURGYFz/VT1WMzAlYD6Nf8S4moM0j1yD5JVatEpC+AhSKy2i5UVRWJLreXEHw5zAKAU089tV2VJaK2S+nIrqpVwf87ATyHxFLNO0SkDACC/3cmee0cVY2paqy0tDQ9tSai49bqkV1EugHIUdUDQXwpgB8BeAHADAD3Bv8/n8mK0snJuSwX6UdLjrnEdoE1KcUFkQkqUtXSenE+SKUZ3w/Ac8FY5zwAT6jqAhF5H8DTIjITwCYA12aumkTUXq0mu6quBzCmmed3A7g4E5UiovTjCDrqNDLdtD6R5n/PBL9/eiKPMNmJPMFkJ/IEk53IE0x2Ik8w2Yk8wWQn8gSTncgTTHYiTzDZiTzBZCfyBJOdyBNMdiJPMNmJPMFkJ/IEk53IE0x2Ik8w2Yk8wWQn8gSTncgTTHYiTzDZiTzBZCfyBJOdyBMpJbuIFIvIsyKyWkQqReQ8EektIgtFZG3wf0nr70REHSXVI/uDABao6kgkloKqBHAngEWqOgLAouAxEXVSrSa7iPQCcCGAhwFAVRtUtRbAdABzg83mArgyM1UkonRI5cg+FEANgN+LyDIR+V2wdHM/Vd0ebFONxGqvRNRJpZLseQDOAfCQqo4DcBCRJruqKgBt7sUiMktEKkSkoqampr31JaI2SiXZtwLYqqqLg8fPIpH8O0SkDACC/3c292JVnaOqMVWNlZaWpqPORNQGrSa7qlYD2CIiZwRPXQxgFYAXAMwInpsB4PmM1JCI0iLV9dlvBvC4iBQAWA/gK0h8UTwtIjMBbAJwbWaqSETpkFKyq+qHAGLNFF2c1toQUcZwBB2RJ5jsRJ5gshN5gslO5AkmO5EnmOxEnmCyE3lCEsPas7QzkRokBuD0AbAraztuXmeoA8B6RLEeruOtx2BVbXZcelaTPdypSIWqNjdIx6s6sB6sRzbrwWY8kSeY7ESe6Khkn9NB+7V1hjoArEcU6+FKWz06pM9ORNnHZjyRJ7Ka7CIyTUTWiMg6EcnabLQi8oiI7BSRFdZzWZ8KW0QGicjrIrJKRFaKyC0dURcR6SIiS0Tko6AedwfPDxWRxcHn81Qwf0HGiUhuML/h/I6qh4hsFJGPReRDEakInuuIv5GMTduetWQXkVwA/wPgcwBGAbheREZlafePApgWea4jpsJuBHCbqo4CMBHAt4LfQbbrUg9giqqOATAWwDQRmQjgPgC/VtXhAPYCmJnhehxzCxLTkx/TUfWYrKpjrUtdHfE3krlp21U1K/8AnAfgZevxbACzs7j/IQBWWI/XACgL4jIAa7JVF6sOzwOY2pF1AdAVwAcAJiAxeCOvuc8rg/svD/6ApwCYD0A6qB4bAfSJPJfVzwVALwAbEJxLS3c9stmMHwhgi/V4a/BcR+nQqbBFZAiAcQAWd0Rdgqbzh0hMFLoQwCcAalW1MdgkW5/PAwBuBxAPHp/SQfVQAK+IyFIRmRU8l+3PJaPTtvMEHVqeCjsTRKQ7gD8C+I6q7u+Iuqhqk6qOReLIOh7AyEzvM0pErgCwU1WXZnvfzZikqucg0c38lohcaBdm6XNp17TtrclmslcBGGQ9Lg+e6ygpTYWdbiKSj0SiP66qf+rIugCAJlb3eR2J5nKxiByblzAbn89nAHxBRDYCmIdEU/7BDqgHVLUq+H8ngOeQ+ALM9ufSrmnbW5PNZH8fwIjgTGsBgOuQmI66o2R9KmwRESSW0apU1V91VF1EpFREioO4CInzBpVIJP012aqHqs5W1XJVHYLE38NrqnpjtushIt1EpMexGMClAFYgy5+LZnra9kyf+IicaLgMwN+R6B/+MIv7fRLAdgBHkfj2nIlE33ARgLUAXgXQOwv1mIREE2w5gA+Df5dluy4AzgawLKjHCgD/GTw/DMASAOsAPAOgMIuf0UUA5ndEPYL9fRT8W3nsb7OD/kbGAqgIPps/AyhJVz04go7IEzxBR+QJJjuRJ5jsRJ5gshN5gslO5AkmO5EnmOxEnmCyE3ni/wETE8t04/X9pgAAAABJRU5ErkJggg==",
|
59 |
+
"text/plain": [
|
60 |
+
"<Figure size 432x288 with 1 Axes>"
|
61 |
+
]
|
62 |
+
},
|
63 |
+
"metadata": {
|
64 |
+
"needs_background": "light"
|
65 |
+
}
|
66 |
+
}
|
67 |
+
],
|
68 |
+
"metadata": {}
|
69 |
+
},
|
70 |
+
{
|
71 |
+
"cell_type": "code",
|
72 |
+
"execution_count": 5,
|
73 |
+
"source": [
|
74 |
+
"from jax_nca.nca import NCA"
|
75 |
+
],
|
76 |
+
"outputs": [],
|
77 |
+
"metadata": {}
|
78 |
+
},
|
79 |
+
{
|
80 |
+
"cell_type": "markdown",
|
81 |
+
"source": [
|
82 |
+
"### NCA\n",
|
83 |
+
"- num_hidden_channels = 16\n",
|
84 |
+
"- num_target_channels = 3\n",
|
85 |
+
"- cell_fire_rate = 1.0 (100% chance for cells to be updated)\n",
|
86 |
+
"- alpha_living_threshold = 0.1 (threshold for cells to be alive)"
|
87 |
+
],
|
88 |
+
"metadata": {}
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"cell_type": "code",
|
92 |
+
"execution_count": 6,
|
93 |
+
"source": [
|
94 |
+
"nca = NCA(16, 3, trainable_perception=False, cell_fire_rate=1.0, alpha_living_threshold=0.1)"
|
95 |
+
],
|
96 |
+
"outputs": [],
|
97 |
+
"metadata": {}
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"cell_type": "code",
|
101 |
+
"execution_count": 7,
|
102 |
+
"source": [
|
103 |
+
"from jax_nca.trainer import EmojiTrainer"
|
104 |
+
],
|
105 |
+
"outputs": [],
|
106 |
+
"metadata": {}
|
107 |
+
},
|
108 |
+
{
|
109 |
+
"cell_type": "code",
|
110 |
+
"execution_count": 8,
|
111 |
+
"source": [
|
112 |
+
"trainer = EmojiTrainer(dataset, nca, n_damage=0)"
|
113 |
+
],
|
114 |
+
"outputs": [],
|
115 |
+
"metadata": {}
|
116 |
+
},
|
117 |
+
{
|
118 |
+
"cell_type": "code",
|
119 |
+
"execution_count": 9,
|
120 |
+
"source": [
|
121 |
+
"# trainer.train(100000, batch_size=8, seed=10, lr=2e-4, min_steps=64, max_steps=96)"
|
122 |
+
],
|
123 |
+
"outputs": [],
|
124 |
+
"metadata": {}
|
125 |
+
},
|
126 |
+
{
|
127 |
+
"cell_type": "markdown",
|
128 |
+
"source": [
|
129 |
+
"#### Get current state from trainer"
|
130 |
+
],
|
131 |
+
"metadata": {}
|
132 |
+
},
|
133 |
+
{
|
134 |
+
"cell_type": "code",
|
135 |
+
"execution_count": 10,
|
136 |
+
"source": [
|
137 |
+
"state = trainer.state"
|
138 |
+
],
|
139 |
+
"outputs": [],
|
140 |
+
"metadata": {}
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"cell_type": "code",
|
144 |
+
"execution_count": 11,
|
145 |
+
"source": [
|
146 |
+
"# save\n",
|
147 |
+
"# nca.save(state.params, \"saved_params\")"
|
148 |
+
],
|
149 |
+
"outputs": [],
|
150 |
+
"metadata": {}
|
151 |
+
},
|
152 |
+
{
|
153 |
+
"cell_type": "code",
|
154 |
+
"execution_count": 12,
|
155 |
+
"source": [
|
156 |
+
"params = nca.load(\"gecko_100_cell_fire_rate\")"
|
157 |
+
],
|
158 |
+
"outputs": [],
|
159 |
+
"metadata": {}
|
160 |
+
},
|
161 |
+
{
|
162 |
+
"cell_type": "code",
|
163 |
+
"execution_count": 13,
|
164 |
+
"source": [
|
165 |
+
"import numpy as np\n",
|
166 |
+
"import matplotlib\n",
|
167 |
+
"import matplotlib.pyplot as plt\n",
|
168 |
+
"%matplotlib ipympl\n",
|
169 |
+
"\n",
|
170 |
+
"plt.style.use('ggplot')\n",
|
171 |
+
"# Imports specifically so we can render outputs in Jupyter.\n",
|
172 |
+
"from JSAnimation.IPython_display import display_animation\n",
|
173 |
+
"from matplotlib import animation\n",
|
174 |
+
"from IPython.display import display\n",
|
175 |
+
"from celluloid import Camera\n",
|
176 |
+
"from IPython.display import HTML\n",
|
177 |
+
"import jax\n",
|
178 |
+
"import jax.numpy as jnp\n",
|
179 |
+
"\n",
|
180 |
+
"def render_nca_steps(nca, params, shape = (64, 64), num_steps = 2):\n",
|
181 |
+
" nca_seed = nca.create_seed(nca.num_hidden_channels, nca.num_target_channels, shape=shape, batch_size=1)\n",
|
182 |
+
" rng = jax.random.PRNGKey(0)\n",
|
183 |
+
" _, outputs = nca.multi_step(params, nca_seed, rng, num_steps=num_steps)\n",
|
184 |
+
" stacked = jnp.squeeze(jnp.stack(outputs))\n",
|
185 |
+
" rgbs = np.array(nca.to_rgb(stacked))\n",
|
186 |
+
"\n",
|
187 |
+
" fig = plt.figure(\"Animation\",figsize=(7,5))\n",
|
188 |
+
" camera = Camera(fig)\n",
|
189 |
+
" ax = fig.add_subplot(111)\n",
|
190 |
+
" frames = []\n",
|
191 |
+
" for r in rgbs:\n",
|
192 |
+
" frame = ax.imshow(r)\n",
|
193 |
+
" ax.axis('off')\n",
|
194 |
+
" camera.snap()\n",
|
195 |
+
" frames.append([frame])\n",
|
196 |
+
" animation = camera.animate(blit=False, interval=50)\n",
|
197 |
+
" animation.save('gecko.mp4')\n",
|
198 |
+
" return animation, outputs, rgbs"
|
199 |
+
],
|
200 |
+
"outputs": [],
|
201 |
+
"metadata": {}
|
202 |
+
},
|
203 |
+
{
|
204 |
+
"cell_type": "code",
|
205 |
+
"execution_count": 14,
|
206 |
+
"source": [
|
207 |
+
"animation, outputs, rgbs = render_nca_steps(nca, params, num_steps=256)"
|
208 |
+
],
|
209 |
+
"outputs": [
|
210 |
+
{
|
211 |
+
"output_type": "display_data",
|
212 |
+
"data": {
|
213 |
+
"application/vnd.jupyter.widget-view+json": {
|
214 |
+
"version_major": 2,
|
215 |
+
"version_minor": 0,
|
216 |
+
"model_id": "c0e2eeba79a046e1a9e1b56275e1c911"
|
217 |
+
},
|
218 |
+
"text/html": [
|
219 |
+
"\n",
|
220 |
+
" <div style=\"display: inline-block;\">\n",
|
221 |
+
" <div class=\"jupyter-widgets widget-label\" style=\"text-align: center;\">\n",
|
222 |
+
" Animation\n",
|
223 |
+
" </div>\n",
|
224 |
+
" <img src='data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAArwAAAH0CAYAAADfWf7fAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZiElEQVR4nO3da5CdB33f8efs2dtZ7a7usmTJsnzD2JDY1Da1zd1JS4ZJgdIEDDEeksk4aWBat9M2tA2UAiE0wJRpMi6lbQopbV2mIVxKU8CkwfgSjI2NjS2IL9gW1sWyZGm1q72ePX3RN5nJ78lEjF3Zf30+L7+Wzh6tNKufnpn9uzMYDAYNAAAUNXSy3wAAADybDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKG34ZL8BgD+vvzqIfXGwFPtQy7/bx4ZGYu90frz3BcDzlye8AACUZvACAFCawQsAQGkGLwAApRm8AACU5koDcFKstPSlzmrs777v9/JPGOQf/8EX/1LsQ5387/zJ7ljLOwLg+c4TXgAASjN4AQAozeAFAKA0gxcAgNIMXgAASnOlAXhWtV1jaPvX9q/c+dHYP7P/T2IfHR6P/dbjfxb71y77QOyrLVcajvTnYt/QXRM7AM89nvACAFCawQsAQGkGLwAApRm8AACUZvACAFBaZzAYDE72mwDqWm3pn997a+xv/85HYh8aH4199vhs/gDD+d/zb97+stjff94vxH5+b2d+/U7OADz3eMILAEBpBi8AAKUZvAAAlGbwAgBQmsELAEBprjQAz4iFJn8pWe0vx77tK2/LrzPSctdhqR9zfyT/u31qJff5lrsRn7zonbG/ZfsrY2+7PtFr8jUJAE4eT3gBACjN4AUAoDSDFwCA0gxeAABKM3gBACht+GS/AaCG8aYT+w17bop9rr8Y++hy/nf4Qn+p5ePmqwgLs/OxD031Yn/Xdz8R+1WnXRz7juFNsZ+o1ZbrFu3y5/lEf/SJvQrA85snvAAAlGbwAgBQmsELAEBpBi8AAKUZvAAAlOZKA3BC2m4KtPXfefCLsXefXol9ea7lesO68dhXWq49NE/MxDz2hfvy6/zk5tg/tutzsf/mhdfGPtHk97na9GM/unI89pnF3A8uHYt94/ja2Hf18q8L4FTiCS8AAKUZvAAAlGbwAgBQmsELAEBpBi8AAKW50gCckLZrDI/O7Yv90OLh2Dddf0fss6PLsS/9k4vz+zlzIvbukXzlYHkmXzno7l6N/SsP3xb7Ry68Lvamn19nKB9paC7941/LLzPWjX1hkK9bXLv11bF/6MJfzu8nv81myGMQoCBf2gAAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM2VBuCEdFquBBzr5H5w4enYh8fzj1+/bWPsM+vyl6uhybHcLz0t9qWWqw7NYj6jcLRZyD++xWx3KfbvHn4w9jPGN8T+xEr+vI22fNn+T4/flF9n7qnY/8Ml/yj2nr8WgII84QUAoDSDFwCA0gxeAABKM3gBACjN4AUAoDTfjguckEHLl401w73Yx3vTsS9/+lWxH39qkD/wwvGYp1YWYx9bl68xHO7kl+8fz9cV5non9lzgwaUDsf/yff869r1PPxn7Uid/Hpa7+eNuGpmK/Q2bXxr7kMcdwCnElzwAAEozeAEAKM3gBQCgNIMXAIDSDF4AAEpzpQE4IYOW6wGnjeRrDOeNboz94dHZ2Ec3zsc+cij/+/y01Zibbj//h3Xr8/WGx1bz+Ya18/njzi7PxX7R6I7Yb7z412P/4Lf/fexfPHxn7KdNbYn9ziv/TezrpzfE3l1oOVcxnjPA85knvAAAlGbwAgBQmsELAEBpBi8AAKUZvAAAlOZKA3BCuk3+7v5eyzf9v3HbFbF/YP/n8+sP9WJftzlfXXjJzELsE938hh5r+rF3pvN5gtkjy7FPrXTz64zkKxYXbDg79s+88n2x7/jPr4t9op/fz+a1m2If7rQ813iWrzHMDfL7HGp5P6NN/ny2PZVp+ePWDFbzn5POkOc7cCrzFQAAgNIMXgAASjN4AQAozeAFAKA0gxcAgNJcaQCeEcP9/O/nd537+tj/YO/NsT+wvC/2pZbvsr+kpV9xcCV/3DPyNYD13dzvH12K/YHlJ2K/sHdW7J0mXy24Zc/dsc89fTT2wWr+9R6YOxT79snNsT/bhlrOKDw8uyf2R+f3xz45nK92LK7k35dXrntx7L2W6x/AqcETXgAASjN4AQAozeAFAKA0gxcAgNIMXgAASusMBoP8P34HeAYMmvwl5pGj+crBxTf/auz9yfz6Fyytxn71lqnYp2dmYt/fcu3h7mY89kdX18R+/Xlvjf3+Pbtj/8yt/yP2g8fy1YXpiXWxP3T912LfMDId+4ma7y/GPtTNx34WmnxF4bo7Pxb7l578Vv7Ag/z7u66ZiP3el98Q+8bpTbF3Oi3nJIBSPOEFAKA0gxcAgNIMXgAASjN4AQAozeAFAKA0VxqAZ9VKS58Z5O/iPzC7N/Yrbn5n7CNNvh5wxbpe7H9nNL+fC/fNx37PmZtj/29z+Vf2wIG52OdH8rWBxaePxX7OsXyW4n++Pl8hOGPrzthXhvMVgvEmX1doM+jnvyq+fuCO2P/2t9+X+6aXxb5n9XDsc4v592W25QrEttV8neMLr/pg7JNjLec/gFI84QUAoDSDFwCA0gxeAABKM3gBACjN4AUAoLQT+zZdgBZt517ybYKmmejk6wFnTm2P/ZbXfCL219x0Xez7+vmKwsGRNbGPHM/XAF53x2OxP3rJ6bEvr81XAh6Zn419Zku+JvHIjunYP71wc+zXPP3q2LdtPjN/3JYrB5Or+a+Ffjf/Du+Y3BL7lt6G2O/q58/n4YV8rWIwyH9OxkbzuY29zUzsQ0Oe78CpzFcAAABKM3gBACjN4AUAoDSDFwCA0gxeAABK6wwGg7Zvrgb4K+u39Pff/6nYx5v83ffvv/+/xP7km78c+wd2/37sHz/4udivXTMW+yXd/CtYu2Uy9qPH5mKfnR2J/fa55djvWcx9ppuvEBwfyZ+3yXycoPnKlf8q9rN7W2O/+cDdsf/xnjtiP218few71uTX//juP4j9e8P7Y1/Kn4bmN85+S+z/4Iw3xD7VWxv7cMufQ6AWT3gBACjN4AUAoDSDFwCA0gxeAABKM3gBACgt/0/TAVqs9PNVgZVOvnLw2Yf/d+zHVuZjP723Lvbh1YXYrz7rqtg/vTd/3HuO5dfprc3frX/hYv51vfR7h2M///F8LmHyyrNy7+QvwwtPHo/94al8ZeLekaXY3/Sdfxb7h8/6u7H//T/8h7EfGMrvZ2J4OvZObyr25S357MLSaP78XzZ9QezX78jXGCbG81UN1xjg1OYJLwAApRm8AACUZvACAFCawQsAQGkGLwAApbnSAJyQoU439jfd8huxbxxbG/uh5dnY9/WPxP6NfffG/tPbr4x9Wz9fD3h0ZSX2DcfzNYbtE4PYn1o7EfvwRfnjnnVuy49fGYn9Z277Qey7N+QrBO95Yb5+sKflGsZbH/hw7J1mMfZmIV/nmBjkz9v568+I/chwfs7S6+Vf1zVbXpE/7lgv9rGWqxfAqc0TXgAASjN4AQAozeAFAKA0gxcAgNIMXgAASvPtrMCJyUcLmtfvfFXsv/vg52Kfy0cFmsHxfCXgS0/cHvtQpxP7Az/6fuy9Tevyj5/I3/U/fSRfdeidPhX77Pqx2F9w31OxL1+6I/a9l26Mfe14/g147XS+cnDXbL7S8O2hfB1i+bStsc88cSj2G3/+htiv2PSi2Jux/HG7w7kPNfn3d6jreQ3wV+crBgAApRm8AACUZvACAFCawQsAQGkGLwAApXUGg0HL91wD/EWry/3YH+/n7+L/iZt+JfaXTp8T+91P7Y796bn8+pOzq7Ef7+brCpu27oz9qen869qwmvtPTOQzE780nI/fXP3FP4v9a792fuzzs/k6wZbF/JziaMvji6Xd+fP2neX8pf9TW9fEfmjfsdi/9ap8peGcjWfEPjGaXx/g2eQJLwAApRm8AACUZvACAFCawQsAQGkGLwAApRm8AACUlu/nALQZ6ca8uTMV+yde8vdi/7nTXhb7v33oD2N/93f/XX4/q7M5H5+L/UOv+YXYN27K58re/KcfjP3Ohfxxhyd7se/+2XyGbXIlfxm+qJ9f/6Iv7In9lp//ydhf0clnwDa2nJf707F8bu0HuzbHftfCQ7G/oHtm7AAngye8AACUZvACAFCawQsAQGkGLwAApRm8AACU1hkMBoOT/SaA57/FwUrsQ51O7EuDfCVgtJ//HX7NHflawm0H7439F7f9zdjfffE1+f108/WJe44+GPvPfOO9sfd25WsVEzMHY3/h6HTs5w4txn7l6ELsm8dOj70/nT/Pex85Gvuefv79ums0X2842l8X+22vuCH21ZY/D0NN/vwDPBM84QUAoDSDFwCA0gxeAABKM3gBACjN4AUAoDRXGoDnlJUmf0labekr/eXY265DjA3lawP9Qb6KsLqanwscHByL/cYHvhj7R777+7Ef3ZLfz+n56EVz4fa1sb+ytxT7mWO92LvH8uftnN3zsX9y03Ds3xzP7/+f77wu9rfs+KnYO03+/QJ4JnjCCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaa40APw4VvKVg73H98f+6hvfEfuDg4OxD593Ruyblruxv2IqP7/46XX5SsMZx/NVihc90Y/966P54/7XdfmvkInBubF//q//ZuydoZHYm5ZrGwAnwhNeAABKM3gBACjN4AUAoDSDFwCA0gxeAABKy/9zdAD+Usv9pdjf/NVfj/1HK8di787n1+kdOB771m35+sE3534U+0LLkYM3rs/POxbOH419eF++6rB2Il9XuPfwI7E/uvhk7LsmdsTuRgPwTPCEFwCA0gxeAABKM3gBACjN4AUAoDSDFwCA0lxpAGiaZqFZjX11OV9LmBlaiP2t57429rtn8xWF/syh2C/YcFbsn738fbHfPv947O958OOxr5+dj/1vjfZjn1g7EfvUsfz5aXr5ecp9R34Y+66J7fl1BjnPr+arEb3ueP4JwCnNE14AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEpzpQGgaZql/lLsb7/rw7H/3PaXxb5pakvs3bF8bmD9ztNjv395f+xvuPVfxn7bT90Q+94zron9U9//ROyXr52O/W/88Ejs/2fjaOzjw53Y7zv2aOyvb14e+5eevDP2l/byFYvetCsNwF/kCS8AAKUZvAAAlGbwAgBQmsELAEBpBi8AAKW50gDQNM3KoB/7zfvuif3x2Sdjf+ypH8be6Y3EvjzI1xsW8w9vHmq53jDf5CsT1257bewfvf8/xv7thYX8gU/Lf10MtTw3mVjIn88fze7Nr9/isqlzY984vOaEXgc4tXnCCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaa40ADRNMzo8Fvs/ffHbY/9fB+6Ifc3a9bE/sXwo9omJjbH/9tlvi/3qs14d+1jLl/ORfu5XT14e+5cHt+XXGYzGvnRsLvbVtfmKwtzqfOz9wWrsWyfy5xPgRHjCCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaa40ADRNM9J0Y//ko38U+97FfHXh+nPeEPt/f+KbsZ89vTP2q6ZfHHt3sRP75Hi+orDazdcPrr3o6thvvOkbsd+/Pj8f2bBxbewjy8uxz3fy6wwt9WNfHs2/L0urS7GPdvOPb/v9PbaUr0xMjeYrE8Dzkye8AACUZvACAFCawQsAQGkGLwAApRm8AACU5koDQNM0nWYQ+5+8/GOx33b0/tjfcedHY79o/dmx75s7EPvOqa2x98bGY2+z1PLruqC3K/Zdo9tif7h/JPa9M0/HvjIxld/Pkf2xz4/F3Ay3vP8DC/lKxg8W98X+wvHTY98+sSn2hdV8ZaJpmmZ8aKT1vwHPTZ7wAgBQmsELAEBpBi8AAKUZvAAAlGbwAgBQmisNAE3TDOdjAM3IWP4yeV4vf9f/G7ddGfvOsc2xH1o+lj9utxv7eHNiFwLGm/w6yy3XD9552bWx/+Jtvx37+pbHJjNzx2MfDD8V+/Hl+diP9g/Hfu6t1+UPPNLy19pS/vVev/11sX/oRe/Ir9M0TX+QX6vb6bT+HODk8oQXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSOoNBy7ebAtAsDJZj77R8R/5Q09bz84W2L8BtX5lHnqFLACstH3m56cd+4dffHvvMfL4ysbK0FPvi+Hjsn7/8vbFfPnl+7N9f3Bf7LYd2x/6PH/y92JuV/Ov96sX/Iv/4pmmu2npp7K40wHOXJ7wAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJTW8j8dB6Bpmma8M3JyPvCz/A3/w63XJPJfCx99ybtif9utvxX7+jVTsc8cz1cdbtjzR7Hv2Lkh9sunzo790u3nxP5bD3829sNDM7Hf8/TDsTdN01y19ZKW/+JKAzxXecILAEBpBi8AAKUZvAAAlGbwAgBQmsELAEBprjQA8Of0Y/3ZDZfF/qYzr4j9xv3fir3bz5cMvvTYrbF38ttpdnXWxn5g7kjswy0HFDr9Qexbevk6RNM0TWeQf44jDfDc5QkvAAClGbwAAJRm8AIAUJrBCwBAaQYvAACldQaDtm83BYD/59hgPvZOJ58muOKrvxr7QyuHY184nl9/bHIy9olON/a5ZiX2TsvfdC/obYv9yy95b/4JTdNsndwS+8hQfk/AyecJLwAApRm8AACUZvACAFCawQsAQGkGLwAApbnSAMCPbbHlKsKhxUOxv+n298f+vYX9sc/NHYl9TS9fb+jmoxHN5qG1sX/hkvfEfu76XfmFmqYZ7o7kj936M4CTzRNeAABKM3gBACjN4AUAoDSDFwCA0gxeAABKc6UBgB/bwmAh9k4n3ywYabllcPfRR2Lff2hf7Lcv5B9/3uTpsb9uw1+LfcOaDbH/ZX8zDg+1nIIAnrM84QUAoDSDFwCA0gxeAABKM3gBACjN4AUAoDRXGgD4/6bf9GNfbfmbaGZ1PvZ1q2P59bv5gsJwy3WIIRcX4JTgCS8AAKUZvAAAlGbwAgBQmsELAEBpBi8AAKW50gAAQGme8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaf8XdahW29FwEN8AAAAASUVORK5CYII=' width=700.0/>\n",
|
225 |
+
" </div>\n",
|
226 |
+
" "
|
227 |
+
],
|
228 |
+
"image/png": "iVBORw0KGgoAAAANSUhEUgAAArwAAAH0CAYAAADfWf7fAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAZiElEQVR4nO3da5CdB33f8efs2dtZ7a7usmTJsnzD2JDY1Da1zd1JS4ZJgdIEDDEeksk4aWBat9M2tA2UAiE0wJRpMi6lbQopbV2mIVxKU8CkwfgSjI2NjS2IL9gW1sWyZGm1q72ePX3RN5nJ78lEjF3Zf30+L7+Wzh6tNKufnpn9uzMYDAYNAAAUNXSy3wAAADybDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEozeAEAKG34ZL8BgD+vvzqIfXGwFPtQy7/bx4ZGYu90frz3BcDzlye8AACUZvACAFCawQsAQGkGLwAApRm8AACU5koDcFKstPSlzmrs777v9/JPGOQf/8EX/1LsQ5387/zJ7ljLOwLg+c4TXgAASjN4AQAozeAFAKA0gxcAgNIMXgAASnOlAXhWtV1jaPvX9q/c+dHYP7P/T2IfHR6P/dbjfxb71y77QOyrLVcajvTnYt/QXRM7AM89nvACAFCawQsAQGkGLwAApRm8AACUZvACAFBaZzAYDE72mwDqWm3pn997a+xv/85HYh8aH4199vhs/gDD+d/zb97+stjff94vxH5+b2d+/U7OADz3eMILAEBpBi8AAKUZvAAAlGbwAgBQmsELAEBprjQAz4iFJn8pWe0vx77tK2/LrzPSctdhqR9zfyT/u31qJff5lrsRn7zonbG/ZfsrY2+7PtFr8jUJAE4eT3gBACjN4AUAoDSDFwCA0gxeAABKM3gBACht+GS/AaCG8aYT+w17bop9rr8Y++hy/nf4Qn+p5ePmqwgLs/OxD031Yn/Xdz8R+1WnXRz7juFNsZ+o1ZbrFu3y5/lEf/SJvQrA85snvAAAlGbwAgBQmsELAEBpBi8AAKUZvAAAlOZKA3BC2m4KtPXfefCLsXefXol9ea7lesO68dhXWq49NE/MxDz2hfvy6/zk5tg/tutzsf/mhdfGPtHk97na9GM/unI89pnF3A8uHYt94/ja2Hf18q8L4FTiCS8AAKUZvAAAlGbwAgBQmsELAEBpBi8AAKW50gCckLZrDI/O7Yv90OLh2Dddf0fss6PLsS/9k4vz+zlzIvbukXzlYHkmXzno7l6N/SsP3xb7Ry68Lvamn19nKB9paC7941/LLzPWjX1hkK9bXLv11bF/6MJfzu8nv81myGMQoCBf2gAAKM3gBQCgNIMXAIDSDF4AAEozeAEAKM2VBuCEdFquBBzr5H5w4enYh8fzj1+/bWPsM+vyl6uhybHcLz0t9qWWqw7NYj6jcLRZyD++xWx3KfbvHn4w9jPGN8T+xEr+vI22fNn+T4/flF9n7qnY/8Ml/yj2nr8WgII84QUAoDSDFwCA0gxeAABKM3gBACjN4AUAoDTfjguckEHLl401w73Yx3vTsS9/+lWxH39qkD/wwvGYp1YWYx9bl68xHO7kl+8fz9cV5non9lzgwaUDsf/yff869r1PPxn7Uid/Hpa7+eNuGpmK/Q2bXxr7kMcdwCnElzwAAEozeAEAKM3gBQCgNIMXAIDSDF4AAEpzpQE4IYOW6wGnjeRrDOeNboz94dHZ2Ec3zsc+cij/+/y01Zibbj//h3Xr8/WGx1bz+Ya18/njzi7PxX7R6I7Yb7z412P/4Lf/fexfPHxn7KdNbYn9ziv/TezrpzfE3l1oOVcxnjPA85knvAAAlGbwAgBQmsELAEBpBi8AAKUZvAAAlOZKA3BCuk3+7v5eyzf9v3HbFbF/YP/n8+sP9WJftzlfXXjJzELsE938hh5r+rF3pvN5gtkjy7FPrXTz64zkKxYXbDg79s+88n2x7/jPr4t9op/fz+a1m2If7rQ813iWrzHMDfL7HGp5P6NN/ny2PZVp+ePWDFbzn5POkOc7cCrzFQAAgNIMXgAASjN4AQAozeAFAKA0gxcAgNJcaQCeEcP9/O/nd537+tj/YO/NsT+wvC/2pZbvsr+kpV9xcCV/3DPyNYD13dzvH12K/YHlJ2K/sHdW7J0mXy24Zc/dsc89fTT2wWr+9R6YOxT79snNsT/bhlrOKDw8uyf2R+f3xz45nK92LK7k35dXrntx7L2W6x/AqcETXgAASjN4AQAozeAFAKA0gxcAgNIMXgAASusMBoP8P34HeAYMmvwl5pGj+crBxTf/auz9yfz6Fyytxn71lqnYp2dmYt/fcu3h7mY89kdX18R+/Xlvjf3+Pbtj/8yt/yP2g8fy1YXpiXWxP3T912LfMDId+4ma7y/GPtTNx34WmnxF4bo7Pxb7l578Vv7Ag/z7u66ZiP3el98Q+8bpTbF3Oi3nJIBSPOEFAKA0gxcAgNIMXgAASjN4AQAozeAFAKA0VxqAZ9VKS58Z5O/iPzC7N/Yrbn5n7CNNvh5wxbpe7H9nNL+fC/fNx37PmZtj/29z+Vf2wIG52OdH8rWBxaePxX7OsXyW4n++Pl8hOGPrzthXhvMVgvEmX1doM+jnvyq+fuCO2P/2t9+X+6aXxb5n9XDsc4v592W25QrEttV8neMLr/pg7JNjLec/gFI84QUAoDSDFwCA0gxeAABKM3gBACjN4AUAoLQT+zZdgBZt517ybYKmmejk6wFnTm2P/ZbXfCL219x0Xez7+vmKwsGRNbGPHM/XAF53x2OxP3rJ6bEvr81XAh6Zn419Zku+JvHIjunYP71wc+zXPP3q2LdtPjN/3JYrB5Or+a+Ffjf/Du+Y3BL7lt6G2O/q58/n4YV8rWIwyH9OxkbzuY29zUzsQ0Oe78CpzFcAAABKM3gBACjN4AUAoDSDFwCA0gxeAABK6wwGg7Zvrgb4K+u39Pff/6nYx5v83ffvv/+/xP7km78c+wd2/37sHz/4udivXTMW+yXd/CtYu2Uy9qPH5mKfnR2J/fa55djvWcx9ppuvEBwfyZ+3yXycoPnKlf8q9rN7W2O/+cDdsf/xnjtiP218few71uTX//juP4j9e8P7Y1/Kn4bmN85+S+z/4Iw3xD7VWxv7cMufQ6AWT3gBACjN4AUAoDSDFwCA0gxeAABKM3gBACgt/0/TAVqs9PNVgZVOvnLw2Yf/d+zHVuZjP723Lvbh1YXYrz7rqtg/vTd/3HuO5dfprc3frX/hYv51vfR7h2M///F8LmHyyrNy7+QvwwtPHo/94al8ZeLekaXY3/Sdfxb7h8/6u7H//T/8h7EfGMrvZ2J4OvZObyr25S357MLSaP78XzZ9QezX78jXGCbG81UN1xjg1OYJLwAApRm8AACUZvACAFCawQsAQGkGLwAApbnSAJyQoU439jfd8huxbxxbG/uh5dnY9/WPxP6NfffG/tPbr4x9Wz9fD3h0ZSX2DcfzNYbtE4PYn1o7EfvwRfnjnnVuy49fGYn9Z277Qey7N+QrBO95Yb5+sKflGsZbH/hw7J1mMfZmIV/nmBjkz9v568+I/chwfs7S6+Vf1zVbXpE/7lgv9rGWqxfAqc0TXgAASjN4AQAozeAFAKA0gxcAgNIMXgAASvPtrMCJyUcLmtfvfFXsv/vg52Kfy0cFmsHxfCXgS0/cHvtQpxP7Az/6fuy9Tevyj5/I3/U/fSRfdeidPhX77Pqx2F9w31OxL1+6I/a9l26Mfe14/g147XS+cnDXbL7S8O2hfB1i+bStsc88cSj2G3/+htiv2PSi2Jux/HG7w7kPNfn3d6jreQ3wV+crBgAApRm8AACUZvACAFCawQsAQGkGLwAApXUGg0HL91wD/EWry/3YH+/n7+L/iZt+JfaXTp8T+91P7Y796bn8+pOzq7Ef7+brCpu27oz9qen869qwmvtPTOQzE780nI/fXP3FP4v9a792fuzzs/k6wZbF/JziaMvji6Xd+fP2neX8pf9TW9fEfmjfsdi/9ap8peGcjWfEPjGaXx/g2eQJLwAApRm8AACUZvACAFCawQsAQGkGLwAApRm8AACUlu/nALQZ6ca8uTMV+yde8vdi/7nTXhb7v33oD2N/93f/XX4/q7M5H5+L/UOv+YXYN27K58re/KcfjP3Ohfxxhyd7se/+2XyGbXIlfxm+qJ9f/6Iv7In9lp//ydhf0clnwDa2nJf707F8bu0HuzbHftfCQ7G/oHtm7AAngye8AACUZvACAFCawQsAQGkGLwAApRm8AACU1hkMBoOT/SaA57/FwUrsQ51O7EuDfCVgtJ//HX7NHflawm0H7439F7f9zdjfffE1+f108/WJe44+GPvPfOO9sfd25WsVEzMHY3/h6HTs5w4txn7l6ELsm8dOj70/nT/Pex85Gvuefv79ums0X2842l8X+22vuCH21ZY/D0NN/vwDPBM84QUAoDSDFwCA0gxeAABKM3gBACjN4AUAoDRXGoDnlJUmf0labekr/eXY265DjA3lawP9Qb6KsLqanwscHByL/cYHvhj7R777+7Ef3ZLfz+n56EVz4fa1sb+ytxT7mWO92LvH8uftnN3zsX9y03Ds3xzP7/+f77wu9rfs+KnYO03+/QJ4JnjCCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaa40APw4VvKVg73H98f+6hvfEfuDg4OxD593Ruyblruxv2IqP7/46XX5SsMZx/NVihc90Y/966P54/7XdfmvkInBubF//q//ZuydoZHYm5ZrGwAnwhNeAABKM3gBACjN4AUAoDSDFwCA0gxeAABKy/9zdAD+Usv9pdjf/NVfj/1HK8di787n1+kdOB771m35+sE3534U+0LLkYM3rs/POxbOH419eF++6rB2Il9XuPfwI7E/uvhk7LsmdsTuRgPwTPCEFwCA0gxeAABKM3gBACjN4AUAoDSDFwCA0lxpAGiaZqFZjX11OV9LmBlaiP2t57429rtn8xWF/syh2C/YcFbsn738fbHfPv947O958OOxr5+dj/1vjfZjn1g7EfvUsfz5aXr5ecp9R34Y+66J7fl1BjnPr+arEb3ueP4JwCnNE14AAEozeAEAKM3gBQCgNIMXAIDSDF4AAEpzpQGgaZql/lLsb7/rw7H/3PaXxb5pakvs3bF8bmD9ztNjv395f+xvuPVfxn7bT90Q+94zron9U9//ROyXr52O/W/88Ejs/2fjaOzjw53Y7zv2aOyvb14e+5eevDP2l/byFYvetCsNwF/kCS8AAKUZvAAAlGbwAgBQmsELAEBpBi8AAKW50gDQNM3KoB/7zfvuif3x2Sdjf+ypH8be6Y3EvjzI1xsW8w9vHmq53jDf5CsT1257bewfvf8/xv7thYX8gU/Lf10MtTw3mVjIn88fze7Nr9/isqlzY984vOaEXgc4tXnCCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaa40ADRNMzo8Fvs/ffHbY/9fB+6Ifc3a9bE/sXwo9omJjbH/9tlvi/3qs14d+1jLl/ORfu5XT14e+5cHt+XXGYzGvnRsLvbVtfmKwtzqfOz9wWrsWyfy5xPgRHjCCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaa40ADRNM9J0Y//ko38U+97FfHXh+nPeEPt/f+KbsZ89vTP2q6ZfHHt3sRP75Hi+orDazdcPrr3o6thvvOkbsd+/Pj8f2bBxbewjy8uxz3fy6wwt9WNfHs2/L0urS7GPdvOPb/v9PbaUr0xMjeYrE8Dzkye8AACUZvACAFCawQsAQGkGLwAApRm8AACU5koDQNM0nWYQ+5+8/GOx33b0/tjfcedHY79o/dmx75s7EPvOqa2x98bGY2+z1PLruqC3K/Zdo9tif7h/JPa9M0/HvjIxld/Pkf2xz4/F3Ay3vP8DC/lKxg8W98X+wvHTY98+sSn2hdV8ZaJpmmZ8aKT1vwHPTZ7wAgBQmsELAEBpBi8AAKUZvAAAlGbwAgBQmisNAE3TDOdjAM3IWP4yeV4vf9f/G7ddGfvOsc2xH1o+lj9utxv7eHNiFwLGm/w6yy3XD9552bWx/+Jtvx37+pbHJjNzx2MfDD8V+/Hl+diP9g/Hfu6t1+UPPNLy19pS/vVev/11sX/oRe/Ir9M0TX+QX6vb6bT+HODk8oQXAIDSDF4AAEozeAEAKM3gBQCgNIMXAIDSOoNBy7ebAtAsDJZj77R8R/5Q09bz84W2L8BtX5lHnqFLACstH3m56cd+4dffHvvMfL4ysbK0FPvi+Hjsn7/8vbFfPnl+7N9f3Bf7LYd2x/6PH/y92JuV/Ov96sX/Iv/4pmmu2npp7K40wHOXJ7wAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJTW8j8dB6Bpmma8M3JyPvCz/A3/w63XJPJfCx99ybtif9utvxX7+jVTsc8cz1cdbtjzR7Hv2Lkh9sunzo790u3nxP5bD3829sNDM7Hf8/TDsTdN01y19ZKW/+JKAzxXecILAEBpBi8AAKUZvAAAlGbwAgBQmsELAEBprjQA8Of0Y/3ZDZfF/qYzr4j9xv3fir3bz5cMvvTYrbF38ttpdnXWxn5g7kjswy0HFDr9Qexbevk6RNM0TWeQf44jDfDc5QkvAAClGbwAAJRm8AIAUJrBCwBAaQYvAACldQaDtm83BYD/59hgPvZOJ58muOKrvxr7QyuHY184nl9/bHIy9olON/a5ZiX2TsvfdC/obYv9yy95b/4JTdNsndwS+8hQfk/AyecJLwAApRm8AACUZvACAFCawQsAQGkGLwAApbnSAMCPbbHlKsKhxUOxv+n298f+vYX9sc/NHYl9TS9fb+jmoxHN5qG1sX/hkvfEfu76XfmFmqYZ7o7kj936M4CTzRNeAABKM3gBACjN4AUAoDSDFwCA0gxeAABKc6UBgB/bwmAh9k4n3ywYabllcPfRR2Lff2hf7Lcv5B9/3uTpsb9uw1+LfcOaDbH/ZX8zDg+1nIIAnrM84QUAoDSDFwCA0gxeAABKM3gBACjN4AUAoDRXGgD4/6bf9GNfbfmbaGZ1PvZ1q2P59bv5gsJwy3WIIRcX4JTgCS8AAKUZvAAAlGbwAgBQmsELAEBpBi8AAKW50gAAQGme8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaQYvAAClGbwAAJRm8AIAUJrBCwBAaf8XdahW29FwEN8AAAAASUVORK5CYII=",
|
229 |
+
"text/plain": [
|
230 |
+
"Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
|
231 |
+
]
|
232 |
+
},
|
233 |
+
"metadata": {}
|
234 |
+
}
|
235 |
+
],
|
236 |
+
"metadata": {}
|
237 |
+
},
|
238 |
+
{
|
239 |
+
"cell_type": "code",
|
240 |
+
"execution_count": 15,
|
241 |
+
"source": [
|
242 |
+
"from IPython.display import Video\n",
|
243 |
+
"\n",
|
244 |
+
"Video(\"gecko.mp4\")"
|
245 |
+
],
|
246 |
+
"outputs": [
|
247 |
+
{
|
248 |
+
"output_type": "execute_result",
|
249 |
+
"data": {
|
250 |
+
"text/html": [
|
251 |
+
"<video src=\"gecko.mp4\" controls >\n",
|
252 |
+
" Your browser does not support the <code>video</code> element.\n",
|
253 |
+
" </video>"
|
254 |
+
],
|
255 |
+
"text/plain": [
|
256 |
+
"<IPython.core.display.Video object>"
|
257 |
+
]
|
258 |
+
},
|
259 |
+
"metadata": {},
|
260 |
+
"execution_count": 15
|
261 |
+
}
|
262 |
+
],
|
263 |
+
"metadata": {}
|
264 |
+
}
|
265 |
+
],
|
266 |
+
"metadata": {
|
267 |
+
"orig_nbformat": 4,
|
268 |
+
"language_info": {
|
269 |
+
"name": "python",
|
270 |
+
"version": "3.9.0",
|
271 |
+
"mimetype": "text/x-python",
|
272 |
+
"codemirror_mode": {
|
273 |
+
"name": "ipython",
|
274 |
+
"version": 3
|
275 |
+
},
|
276 |
+
"pygments_lexer": "ipython3",
|
277 |
+
"nbconvert_exporter": "python",
|
278 |
+
"file_extension": ".py"
|
279 |
+
},
|
280 |
+
"kernelspec": {
|
281 |
+
"name": "python3",
|
282 |
+
"display_name": "Python 3.9.0 64-bit ('jax_gpu': conda)"
|
283 |
+
},
|
284 |
+
"interpreter": {
|
285 |
+
"hash": "a7271dcc4a91420ffb9cc5ce7ff5a5d83d948f729c0ba20dec48f9a748a86390"
|
286 |
+
}
|
287 |
+
},
|
288 |
+
"nbformat": 4,
|
289 |
+
"nbformat_minor": 2
|
290 |
+
}
|
notebooks/gecko.mp4
ADDED
Binary file (66.5 kB). View file
|
|
notebooks/gecko_100_cell_fire_rate
ADDED
Binary file (37.5 kB). View file
|
|
setup.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
from os import path
|
5 |
+
|
6 |
+
from setuptools import find_packages
|
7 |
+
from setuptools import setup
|
8 |
+
|
9 |
+
|
10 |
+
this_directory = path.abspath(path.dirname(__file__))
|
11 |
+
with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f:
|
12 |
+
long_description = f.read()
|
13 |
+
|
14 |
+
|
15 |
+
setup(
|
16 |
+
name="jax_nca",
|
17 |
+
version="0.1.2",
|
18 |
+
url="https://github.com/shyamsn97/jax-nca",
|
19 |
+
license='MIT',
|
20 |
+
|
21 |
+
author="Shyam Sudhakaran",
|
22 |
+
author_email="shyamsnair@protonmail.com",
|
23 |
+
|
24 |
+
description="Neural Cellular Automata (https://distill.pub/2020/growing-ca/ -- Mordvintsev, et al., \"Growing Neural Cellular Automata\", Distill, 2020) implemented in JAX",
|
25 |
+
|
26 |
+
long_description=long_description,
|
27 |
+
long_description_content_type="text/markdown",
|
28 |
+
|
29 |
+
packages=find_packages(exclude=('tests',)),
|
30 |
+
|
31 |
+
install_requires=[
|
32 |
+
'numpy',
|
33 |
+
'jax',
|
34 |
+
'flax',
|
35 |
+
'matplotlib',
|
36 |
+
'einops',
|
37 |
+
'tensorflow',
|
38 |
+
'tensorboardX',
|
39 |
+
'optax',
|
40 |
+
'tqdm',
|
41 |
+
'pandas',
|
42 |
+
'pillow',
|
43 |
+
'ipycanvas',
|
44 |
+
'orjson',
|
45 |
+
'opencv-python'
|
46 |
+
],
|
47 |
+
|
48 |
+
classifiers=[
|
49 |
+
'Development Status :: 2 - Pre-Alpha',
|
50 |
+
'License :: OSI Approved :: MIT License',
|
51 |
+
'Programming Language :: Python :: 3.9',
|
52 |
+
],
|
53 |
+
)
|