shyamsn97 commited on
Commit
f595bfc
1 Parent(s): 434b57f

update readme

Browse files
Files changed (1) hide show
  1. README.md +14 -15
README.md CHANGED
@@ -1,32 +1,31 @@
1
- # Neural Cellular Automata (Based on https://distill.pub/2020/growing-ca/) implemented in Jax (Flax)
2
-
3
- ![Gecko gif](https://raw.githubusercontent.com/shyamsn97/jax-nca/main/images/gecko.gif?token=GHSAT0AAAAAABTB4G7FLAJSLDHSIOQONS3IYTB5ZEA)
4
-
5
  ---
6
 
 
 
7
 
8
  ## Installation
9
- from source:
10
- ```
 
11
  git clone git@github.com:shyamsn97/jax-nca.git
12
  cd jax-nca
13
  python setup.py install
14
  ```
15
 
16
  from PYPI
17
- ```
 
18
  pip install jax-nca
19
  ```
20
- ---
21
 
22
  ## How do NCAs work?
23
  For more information, view the awesome article https://distill.pub/2020/growing-ca/ -- Mordvintsev, et al., "Growing Neural Cellular Automata", Distill, 2020
24
 
25
  Image below describes a single update step: https://github.com/distillpub/post--growing-ca/blob/master/public/figures/model.svg
26
 
27
- ![NCA update](https://raw.githubusercontent.com/shyamsn97/jax-nca/main/images/model.svg?token=GHSAT0AAAAAABTB4G7FOWOPXEUYVLBGRNSWYTB5YUA)
28
-
29
- ---
30
 
31
  ## Why Jax?
32
 
@@ -34,7 +33,7 @@ Image below describes a single update step: https://github.com/distillpub/post--
34
 
35
  NCAs are autoregressive models like RNNs, where new states are calculated from previous ones. With jax, we can make these operations a lot more performant with `jax.lax.scan` and `jax.jit` (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html)
36
 
37
- Instead of writing the nca growth process as:
38
 
39
  ```python
40
  def multi_step(params, nca, current_state, num_steps):
@@ -66,14 +65,13 @@ def multi_step(params, nca, current_state, num_steps):
66
  ```
67
  The actual multi_step implementation can be found here: https://github.com/shyamsn97/jax-nca/blob/main/jax_nca/nca.py#L103
68
 
69
- ---
70
 
71
  ## Usage
72
  See [notebooks/Gecko.ipynb](notebooks/Gecko.ipynb) for a full example
73
 
74
  <b> Currently there's a bug with the stochastic update, so only `cell_fire_rate = 1.0` works at the moment </b>
75
 
76
- Creating and using NCA:
77
 
78
  ```python
79
  class NCA(nn.Module):
@@ -117,7 +115,8 @@ update = nca.apply({"params":params}, nca_seed, jax.random.PRNGKey(10))
117
  final_state, nca_states = nca.multi_step(poarams, nca_seed, jax.random.PRNGKey(10), num_steps=32)
118
  ```
119
 
120
- To train the NCA:
 
121
  ```python
122
  from jax_nca.dataset import ImageDataset
123
  from jax_nca.trainer import EmojiTrainer
1
+ ---
2
+ tags:
3
+ - image-generation
 
4
  ---
5
 
6
+ # Neural Cellular Automata (Based on https://distill.pub/2020/growing-ca/) implemented in Jax (Flax)
7
+
8
 
9
  ## Installation
10
+ from source
11
+
12
+ ```bash
13
  git clone git@github.com:shyamsn97/jax-nca.git
14
  cd jax-nca
15
  python setup.py install
16
  ```
17
 
18
  from PYPI
19
+
20
+ ```bash
21
  pip install jax-nca
22
  ```
 
23
 
24
  ## How do NCAs work?
25
  For more information, view the awesome article https://distill.pub/2020/growing-ca/ -- Mordvintsev, et al., "Growing Neural Cellular Automata", Distill, 2020
26
 
27
  Image below describes a single update step: https://github.com/distillpub/post--growing-ca/blob/master/public/figures/model.svg
28
 
 
 
 
29
 
30
  ## Why Jax?
31
 
33
 
34
  NCAs are autoregressive models like RNNs, where new states are calculated from previous ones. With jax, we can make these operations a lot more performant with `jax.lax.scan` and `jax.jit` (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html)
35
 
36
+ Instead of writing the nca growth process as
37
 
38
  ```python
39
  def multi_step(params, nca, current_state, num_steps):
65
  ```
66
  The actual multi_step implementation can be found here: https://github.com/shyamsn97/jax-nca/blob/main/jax_nca/nca.py#L103
67
 
 
68
 
69
  ## Usage
70
  See [notebooks/Gecko.ipynb](notebooks/Gecko.ipynb) for a full example
71
 
72
  <b> Currently there's a bug with the stochastic update, so only `cell_fire_rate = 1.0` works at the moment </b>
73
 
74
+ Creating and using NCA
75
 
76
  ```python
77
  class NCA(nn.Module):
115
  final_state, nca_states = nca.multi_step(poarams, nca_seed, jax.random.PRNGKey(10), num_steps=32)
116
  ```
117
 
118
+ To train the NCA
119
+
120
  ```python
121
  from jax_nca.dataset import ImageDataset
122
  from jax_nca.trainer import EmojiTrainer