shyamsn97
commited on
Commit
•
f595bfc
1
Parent(s):
434b57f
update readme
Browse files
README.md
CHANGED
@@ -1,32 +1,31 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
---
|
6 |
|
|
|
|
|
7 |
|
8 |
## Installation
|
9 |
-
from source
|
10 |
-
|
|
|
11 |
git clone git@github.com:shyamsn97/jax-nca.git
|
12 |
cd jax-nca
|
13 |
python setup.py install
|
14 |
```
|
15 |
|
16 |
from PYPI
|
17 |
-
|
|
|
18 |
pip install jax-nca
|
19 |
```
|
20 |
-
---
|
21 |
|
22 |
## How do NCAs work?
|
23 |
For more information, view the awesome article https://distill.pub/2020/growing-ca/ -- Mordvintsev, et al., "Growing Neural Cellular Automata", Distill, 2020
|
24 |
|
25 |
Image below describes a single update step: https://github.com/distillpub/post--growing-ca/blob/master/public/figures/model.svg
|
26 |
|
27 |
-
![NCA update](https://raw.githubusercontent.com/shyamsn97/jax-nca/main/images/model.svg?token=GHSAT0AAAAAABTB4G7FOWOPXEUYVLBGRNSWYTB5YUA)
|
28 |
-
|
29 |
-
---
|
30 |
|
31 |
## Why Jax?
|
32 |
|
@@ -34,7 +33,7 @@ Image below describes a single update step: https://github.com/distillpub/post--
|
|
34 |
|
35 |
NCAs are autoregressive models like RNNs, where new states are calculated from previous ones. With jax, we can make these operations a lot more performant with `jax.lax.scan` and `jax.jit` (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html)
|
36 |
|
37 |
-
Instead of writing the nca growth process as
|
38 |
|
39 |
```python
|
40 |
def multi_step(params, nca, current_state, num_steps):
|
@@ -66,14 +65,13 @@ def multi_step(params, nca, current_state, num_steps):
|
|
66 |
```
|
67 |
The actual multi_step implementation can be found here: https://github.com/shyamsn97/jax-nca/blob/main/jax_nca/nca.py#L103
|
68 |
|
69 |
-
---
|
70 |
|
71 |
## Usage
|
72 |
See [notebooks/Gecko.ipynb](notebooks/Gecko.ipynb) for a full example
|
73 |
|
74 |
<b> Currently there's a bug with the stochastic update, so only `cell_fire_rate = 1.0` works at the moment </b>
|
75 |
|
76 |
-
Creating and using NCA
|
77 |
|
78 |
```python
|
79 |
class NCA(nn.Module):
|
@@ -117,7 +115,8 @@ update = nca.apply({"params":params}, nca_seed, jax.random.PRNGKey(10))
|
|
117 |
final_state, nca_states = nca.multi_step(poarams, nca_seed, jax.random.PRNGKey(10), num_steps=32)
|
118 |
```
|
119 |
|
120 |
-
To train the NCA
|
|
|
121 |
```python
|
122 |
from jax_nca.dataset import ImageDataset
|
123 |
from jax_nca.trainer import EmojiTrainer
|
|
|
1 |
+
---
|
2 |
+
tags:
|
3 |
+
- image-generation
|
|
|
4 |
---
|
5 |
|
6 |
+
# Neural Cellular Automata (Based on https://distill.pub/2020/growing-ca/) implemented in Jax (Flax)
|
7 |
+
|
8 |
|
9 |
## Installation
|
10 |
+
from source
|
11 |
+
|
12 |
+
```bash
|
13 |
git clone git@github.com:shyamsn97/jax-nca.git
|
14 |
cd jax-nca
|
15 |
python setup.py install
|
16 |
```
|
17 |
|
18 |
from PYPI
|
19 |
+
|
20 |
+
```bash
|
21 |
pip install jax-nca
|
22 |
```
|
|
|
23 |
|
24 |
## How do NCAs work?
|
25 |
For more information, view the awesome article https://distill.pub/2020/growing-ca/ -- Mordvintsev, et al., "Growing Neural Cellular Automata", Distill, 2020
|
26 |
|
27 |
Image below describes a single update step: https://github.com/distillpub/post--growing-ca/blob/master/public/figures/model.svg
|
28 |
|
|
|
|
|
|
|
29 |
|
30 |
## Why Jax?
|
31 |
|
|
|
33 |
|
34 |
NCAs are autoregressive models like RNNs, where new states are calculated from previous ones. With jax, we can make these operations a lot more performant with `jax.lax.scan` and `jax.jit` (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html)
|
35 |
|
36 |
+
Instead of writing the nca growth process as
|
37 |
|
38 |
```python
|
39 |
def multi_step(params, nca, current_state, num_steps):
|
|
|
65 |
```
|
66 |
The actual multi_step implementation can be found here: https://github.com/shyamsn97/jax-nca/blob/main/jax_nca/nca.py#L103
|
67 |
|
|
|
68 |
|
69 |
## Usage
|
70 |
See [notebooks/Gecko.ipynb](notebooks/Gecko.ipynb) for a full example
|
71 |
|
72 |
<b> Currently there's a bug with the stochastic update, so only `cell_fire_rate = 1.0` works at the moment </b>
|
73 |
|
74 |
+
Creating and using NCA
|
75 |
|
76 |
```python
|
77 |
class NCA(nn.Module):
|
|
|
115 |
final_state, nca_states = nca.multi_step(poarams, nca_seed, jax.random.PRNGKey(10), num_steps=32)
|
116 |
```
|
117 |
|
118 |
+
To train the NCA
|
119 |
+
|
120 |
```python
|
121 |
from jax_nca.dataset import ImageDataset
|
122 |
from jax_nca.trainer import EmojiTrainer
|