---
tags:
- image-generation
---
# Neural Cellular Automata (Based on https://distill.pub/2020/growing-ca/) implemented in Jax (Flax)
## Installation
From source:
```bash
git clone git@github.com:shyamsn97/jax-nca.git
cd jax-nca
python setup.py install
```
From PyPI:
```bash
pip install jax-nca
```
## How do NCAs work?
For more information, see the excellent article https://distill.pub/2020/growing-ca/ (Mordvintsev et al., "Growing Neural Cellular Automata", Distill, 2020).
The figure linked below illustrates a single update step: https://github.com/distillpub/post--growing-ca/blob/master/public/figures/model.svg
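For a concrete picture, here is a minimal, self-contained sketch of one update step: depthwise Sobel/identity perception, a small per-cell MLP, a stochastic fire mask, and an alive mask. All names, shapes, and the "channel 3 is alpha" convention are illustrative assumptions, not the library's actual implementation (see `jax_nca/nca.py` for that).
```python
import jax
import jax.numpy as jnp


def perceive(state):
    # Depthwise 3x3 convolution of every channel with identity + Sobel kernels.
    identity = jnp.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]], dtype=jnp.float32)
    sobel_x = jnp.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=jnp.float32) / 8.0
    c = state.shape[-1]
    outs = []
    for kernel in [identity, sobel_x, sobel_x.T]:
        k = jnp.tile(kernel[:, :, None, None], (1, 1, 1, c))  # (3, 3, 1, C)
        outs.append(
            jax.lax.conv_general_dilated(
                state, k, window_strides=(1, 1), padding="SAME",
                dimension_numbers=("NHWC", "HWIO", "NHWC"), feature_group_count=c,
            )
        )
    return jnp.concatenate(outs, axis=-1)  # (B, H, W, 3 * C)


def alive_mask(state, alpha_living_threshold=0.1):
    # A cell is "alive" if any cell in its 3x3 neighborhood has alpha above the threshold.
    alpha = state[..., 3:4]
    neighborhood_max = jax.lax.reduce_window(
        alpha, -jnp.inf, jax.lax.max, (1, 3, 3, 1), (1, 1, 1, 1), "SAME"
    )
    return neighborhood_max > alpha_living_threshold


def update_step(rng, state, w1, b1, w2, cell_fire_rate=1.0):
    # state: (batch, height, width, channels); w1, b1, w2: per-cell MLP weights.
    pre_alive = alive_mask(state)

    # Per-cell residual update: perception followed by a small MLP (1x1-conv style).
    y = perceive(state)
    ds = jax.nn.relu(y @ w1 + b1) @ w2

    # Stochastic update: each cell fires independently with probability cell_fire_rate.
    fire = jax.random.uniform(rng, state.shape[:3] + (1,)) <= cell_fire_rate
    state = state + ds * fire

    # Zero out cells that are dead both before and after the update.
    return state * (pre_alive & alive_mask(state))
```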
## Why Jax?
<b> Note: This project served as a nice introduction to jax, so its performance can probably be improved </b>
NCAs are autoregressive models, like RNNs, where new states are computed from previous ones. With `jax.lax.scan` and `jax.jit`, these sequential updates can be made much more performant (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html).
Instead of writing the NCA growth process as
```python
def multi_step(params, nca, current_state, num_steps):
    # params: parameters for the NCA
    # nca: Flax Module describing the NCA
    # current_state: current NCA state
    # num_steps: number of steps to run
    for i in range(num_steps):
        current_state = nca.apply({"params": params}, current_state)
    return current_state
```
we can write it with `jax.lax.scan`:
```python
def multi_step(params, nca, current_state, num_steps):
    # params: parameters for the NCA
    # nca: Flax Module describing the NCA
    # current_state: current NCA state
    # num_steps: number of steps to run
    def forward(carry, inp):
        carry = nca.apply({"params": params}, carry)
        return carry, carry

    final_state, nca_states = jax.lax.scan(forward, current_state, None, length=num_steps)
    return final_state
```
The actual multi_step implementation can be found here: https://github.com/shyamsn97/jax-nca/blob/main/jax_nca/nca.py#L103
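Wrapping the scan-based rollout in `jax.jit` then compiles the whole multi-step process into a single XLA program. A hedged sketch, mirroring the simplified `multi_step` above rather than the library's exact signature:
```python
import jax

def make_jitted_multi_step(nca, num_steps: int):
    # Close over the module and step count so jit treats them as static,
    # then compile the scan-based rollout once and reuse it.
    def multi_step(params, current_state):
        def forward(carry, inp):
            carry = nca.apply({"params": params}, carry)
            return carry, carry

        return jax.lax.scan(forward, current_state, None, length=num_steps)

    return jax.jit(multi_step)

# usage: rollout = make_jitted_multi_step(nca, num_steps=64)
#        final_state, nca_states = rollout(params, current_state)
```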
## Usage
See [notebooks/Gecko.ipynb](notebooks/Gecko.ipynb) for a full example
<b> Currently there's a bug with the stochastic update, so only `cell_fire_rate = 1.0` works at the moment </b>
Creating and using an NCA:
```python
class NCA(nn.Module):
    num_hidden_channels: int
    num_target_channels: int = 3
    alpha_living_threshold: float = 0.1
    cell_fire_rate: float = 1.0
    trainable_perception: bool = False
    alpha: float = 1.0
    """
    num_hidden_channels: number of hidden channels for each cell to use
    num_target_channels: number of target channels to be used
    alpha_living_threshold: threshold to determine whether a cell lives or dies
    cell_fire_rate: probability that a cell receives an update per step
    trainable_perception: if true, use a trainable conv net instead of sobel filters
    alpha: scalar value to be multiplied to updates
    """
    ...

import jax

from jax_nca.nca import NCA

# usage
nca = NCA(
    num_hidden_channels = 16,
    num_target_channels = 3,
    trainable_perception = False,
    cell_fire_rate = 1.0,
    alpha_living_threshold = 0.1
)

nca_seed = nca.create_seed(
    nca.num_hidden_channels, nca.num_target_channels, shape=(64, 64), batch_size=1
)
rng = jax.random.PRNGKey(0)
params = nca.init(rng, nca_seed, rng)["params"]

# single step
update = nca.apply({"params": params}, nca_seed, jax.random.PRNGKey(10))

# multi step
final_state, nca_states = nca.multi_step(params, nca_seed, jax.random.PRNGKey(10), num_steps=32)
```
To train the NCA:
```python
from jax_nca.dataset import ImageDataset
from jax_nca.trainer import EmojiTrainer
dataset = ImageDataset(emoji='🦎', img_size=64)
nca = NCA(
num_hidden_channels = 16,
num_target_channels = 3,
trainable_perception = False,
cell_fire_rate = 1.0,
alpha_living_threshold = 0.1
)
trainer = EmojiTrainer(dataset, nca, n_damage=0)
trainer.train(100000, batch_size=8, seed=10, lr=2e-4, min_steps=64, max_steps=96)
# to access train state:
state = trainer.state
# save
nca.save(state.params, "saved_params")
# load params
loaded_params = nca.load("saved_params")
```
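Once trained, you can grow an image from a fresh seed using the saved parameters. A hedged sketch: the assumption that the state is `(batch, height, width, channels)` with the first `num_target_channels` channels holding the RGB output is illustrative, so check `jax_nca/nca.py` for the actual state layout.
```python
import jax
import numpy as np

# Grow an image from a fresh seed with the loaded parameters.
nca_seed = nca.create_seed(
    nca.num_hidden_channels, nca.num_target_channels, shape=(64, 64), batch_size=1
)
final_state, nca_states = nca.multi_step(
    loaded_params, nca_seed, jax.random.PRNGKey(0), num_steps=96
)

# Assumption: state is (batch, height, width, channels) and the first
# num_target_channels channels are the RGB output.
rgb = np.clip(np.array(final_state[0, ..., : nca.num_target_channels]), 0.0, 1.0)
```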