violet1723 commited on
Commit
00c2650
·
verified ·
1 Parent(s): 8a0e6a7

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -1,35 +1,11 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ # Auto detect text files and perform LF normalization
2
+ * text=auto
3
+ # SCM syntax highlighting & preventing 3-way merges
4
+ pixi.lock merge=binary linguist-language=YAML linguist-generated=true
5
+ docs/LSTM_hexagons.png filter=lfs diff=lfs merge=lfs -text
6
+ docs/RNNgrids.png filter=lfs diff=lfs merge=lfs -text
7
+ docs/poisson_spiking.gif filter=lfs diff=lfs merge=lfs -text
8
+ models/example_pc_centers.npy filter=lfs diff=lfs merge=lfs -text
9
+ models/example_trained_weights.npy filter=lfs diff=lfs merge=lfs -text
10
+ models/steps_20_batch_200_RNN_4096_relu_rf_012_DoG_True_periodic_False_lr_00001_weight_decay_00001/1.png filter=lfs diff=lfs merge=lfs -text
11
+ models/steps_20_batch_200_RNN_4096_relu_rf_012_DoG_True_periodic_False_lr_00001_weight_decay_00001/most_recent_model.pth filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ images/
7
+ movies/
8
+ data/
9
+ fixed_point_finder_data/
10
+ ben_decimate_1000_step/
11
+ ben_true_100_step/
12
+ ben_100*/
13
+ data/
14
+ nohup.out
15
+
16
+ # C extensions
17
+ *.so
18
+
19
+ # Distribution / packaging
20
+ .Python
21
+ build/
22
+ develop-eggs/
23
+ dist/
24
+ downloads/
25
+ eggs/
26
+ .eggs/
27
+ lib/
28
+ lib64/
29
+ parts/
30
+ sdist/
31
+ var/
32
+ wheels/
33
+ *.egg-info/
34
+ .installed.cfg
35
+ *.egg
36
+ MANIFEST
37
+
38
+ # PyInstaller
39
+ # Usually these files are written by a python script from a template
40
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
41
+ *.manifest
42
+ *.spec
43
+
44
+ # Installer logs
45
+ pip-log.txt
46
+ pip-delete-this-directory.txt
47
+
48
+ # Unit test / coverage reports
49
+ htmlcov/
50
+ .tox/
51
+ .coverage
52
+ .coverage.*
53
+ .cache
54
+ nosetests.xml
55
+ coverage.xml
56
+ *.cover
57
+ .hypothesis/
58
+ .pytest_cache/
59
+
60
+ # Translations
61
+ *.mo
62
+ *.pot
63
+
64
+ # Django stuff:
65
+ *.log
66
+ local_settings.py
67
+ db.sqlite3
68
+
69
+ # Flask stuff:
70
+ instance/
71
+ .webassets-cache
72
+
73
+ # Scrapy stuff:
74
+ .scrapy
75
+
76
+ # Sphinx documentation
77
+ docs/_build/
78
+
79
+ # PyBuilder
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # pyenv
86
+ .python-version
87
+
88
+ # celery beat schedule file
89
+ celerybeat-schedule
90
+
91
+ # SageMath parsed files
92
+ *.sage.py
93
+
94
+ # Environments
95
+ .env
96
+ .venv
97
+ env/
98
+ venv/
99
+ ENV/
100
+ env.bak/
101
+ venv.bak/
102
+
103
+ # Spyder project settings
104
+ .spyderproject
105
+ .spyproject
106
+
107
+ # Rope project settings
108
+ .ropeproject
109
+
110
+ # mkdocs documentation
111
+ /site
112
+
113
+ # mypy
114
+ .mypy_cache/
115
+
116
+
117
+ # by user
118
+ saved
119
+ *.ipynb
120
+ *.zip
121
+ # pixi environments
122
+ .pixi/*
123
+ !.pixi/config.toml
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,69 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [![DOI](https://zenodo.org/badge/217773694.svg)](https://zenodo.org/badge/latestdoi/217773694)
2
+
3
+ # Grid cells in RNNs trained to path integrate
4
+
5
+ Code to reproduce the trained RNN in [**a unified theory for the origin of grid cells through the lens of pattern formation (NeurIPS '19)**](https://papers.nips.cc/paper/9191-a-unified-theory-for-the-origin-of-grid-cells-through-the-lens-of-pattern-formation) and additional analysis described in this [**preprint**](https://www.biorxiv.org/content/10.1101/2020.12.29.424583v1).
6
+
7
+
8
+ Quick start:
9
+
10
+ <img src="./docs/poisson_spiking.gif" width="300" align="right">
11
+
12
+ * [**inspect_model.ipynb**](inspect_model.ipynb):
13
+ Train a model and visualize its hidden unit ratemaps.
14
+
15
+ * [**main.py**](main.py):
16
+ or, train a model from the command line.
17
+
18
+ * [**pattern_formation.ipynb**](pattern_formation.ipynb):
19
+ Numerical simulations of pattern-forming dynamics.
20
+
21
+
22
+ Includes:
23
+
24
+ * [**trajectory_generator.py**](trajectory_generator.py):
25
+ Generate simulated rat trajectories in a rectangular environment.
26
+
27
+ * [**place_cells.py**](place_cells.py):
28
+ Tile a set of simulated place cells across the training environment.
29
+
30
+ * [**model.py**](model.py):
31
+ Contains the vanilla RNN model architecture, as well as an LSTM.
32
+
33
+ * [**trainer.py**](model.py):
34
+ Contains model training loop.
35
+
36
+ * [**models/example_trained_weights.npy**](models/example_trained_weights.npy)
37
+ Contains a set of pre-trained weights.
38
+
39
+ ## Running
40
+
41
+ We recommend creating a virtual environment:
42
+
43
+ ```shell
44
+ $ virtualenv env
45
+ $ source env/bin/activate
46
+ $ pip install --upgrade pip
47
+ ```
48
+
49
+ Then, install the dependencies automatically with `pip install -r requirements.txt`
50
+ or manually with:
51
+
52
+ ```shell
53
+ $ pip install --upgrade numpy==1.17.2
54
+ $ pip install --upgrade tensorflow==2.0.0rc2
55
+ $ pip install --upgrade scipy==1.4.1
56
+ $ pip install --upgrade matplotlib==3.0.3
57
+ $ pip install --upgrade imageio==2.5.0
58
+ $ pip install --upgrade opencv-python==4.1.1.26
59
+ $ pip install --upgrade tqdm==4.36.0
60
+ $ pip install --upgrade opencv-python==4.1.1.26
61
+ $ pip install --upgrade torch==1.10.0
62
+ ```
63
+
64
+ If you want to train your own models, make sure to properly set the default
65
+ save directory in `main.py`!
66
+
67
+ ## Result
68
+
69
+ ![grid visualization](./docs/RNNgrids.png)
docs/LSTM_hexagons.png ADDED

Git LFS Details

  • SHA256: c4aab83e5e7f68f197e9ade73a7e185e9c37b03cf4ce04ce0dc82d33b138ae21
  • Pointer size: 132 Bytes
  • Size of remote file: 1.48 MB
docs/RNNgrids.png ADDED

Git LFS Details

  • SHA256: 9c9cde71769f3a7990115afa8c4b3e14efd3ad549d547346a3fe896ac33e4076
  • Pointer size: 132 Bytes
  • Size of remote file: 2.95 MB
docs/poisson_spiking.gif ADDED

Git LFS Details

  • SHA256: 0016ada81c06ecf3eb22fd60c8aa57380899d77e9c7d19a27a8ac50138df78fd
  • Pointer size: 131 Bytes
  • Size of remote file: 318 kB
main.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ import torch.cuda
4
+ import argparse
5
+
6
+
7
+ from utils import generate_run_ID
8
+ from place_cells import PlaceCells
9
+ from trajectory_generator import TrajectoryGenerator
10
+ from model import RNN
11
+ from trainer import Trainer
12
+
13
+ tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
14
+
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument(
17
+ "--save_dir",
18
+ # default='/mnt/fs2/bsorsch/grid_cells/models/',
19
+ default="models/",
20
+ help="directory to save trained models",
21
+ )
22
+ parser.add_argument(
23
+ "--n_epochs", default=100, help="number of training epochs", type=int
24
+ )
25
+ parser.add_argument("--n_steps", default=1000, help="batches per epoch", type=int)
26
+ parser.add_argument(
27
+ "--batch_size", default=200, help="number of trajectories per batch", type=int
28
+ )
29
+ parser.add_argument(
30
+ "--sequence_length", default=20, help="number of steps in trajectory", type=int
31
+ )
32
+ parser.add_argument(
33
+ "--learning_rate", default=1e-4, help="gradient descent learning rate", type=float
34
+ )
35
+ parser.add_argument("--Np", default=512, help="number of place cells", type=int)
36
+ parser.add_argument("--Ng", default=4096, help="number of grid cells", type=int)
37
+ parser.add_argument(
38
+ "--place_cell_rf",
39
+ default=0.12,
40
+ help="width of place cell center tuning curve (m)",
41
+ type=float,
42
+ )
43
+ parser.add_argument(
44
+ "--surround_scale",
45
+ default=2,
46
+ help="if DoG, ratio of sigma2^2 to sigma1^2",
47
+ type=int,
48
+ )
49
+ parser.add_argument("--RNN_type", default="RNN", help="RNN or LSTM")
50
+ parser.add_argument("--activation", default="relu", help="recurrent nonlinearity")
51
+ parser.add_argument(
52
+ "--weight_decay",
53
+ default=1e-4,
54
+ help="strength of weight decay on recurrent weights",
55
+ type=float,
56
+ )
57
+ parser.add_argument(
58
+ "--DoG", default=True, help="use difference of gaussians tuning curves"
59
+ )
60
+ parser.add_argument(
61
+ "--periodic", default=False, help="trajectories with periodic boundary conditions"
62
+ )
63
+ parser.add_argument(
64
+ "--box_width", default=2.2, help="width of training environment", type=float
65
+ )
66
+ parser.add_argument(
67
+ "--box_height", default=2.2, help="height of training environment", type=float
68
+ )
69
+ parser.add_argument(
70
+ "--device",
71
+ default="cuda" if torch.cuda.is_available() else "cpu",
72
+ help="device to use for training",
73
+ )
74
+ parser.add_argument(
75
+ "--seed", default=None, help="seed number for all numpy random number generator"
76
+ )
77
+
78
+ options = parser.parse_args()
79
+ options.run_ID = generate_run_ID(options)
80
+
81
+ print(f"Using device: {options.device}")
82
+
83
+ if options.seed:
84
+ np.random.seed(int(options.seed))
85
+
86
+ place_cells = PlaceCells(options)
87
+ if options.RNN_type == "RNN":
88
+ model = RNN(options, place_cells)
89
+ elif options.RNN_type == "LSTM":
90
+ # model = LSTM(options, place_cells)
91
+ raise NotImplementedError
92
+
93
+ # Put model on GPU if using GPU
94
+ if options.device == "cuda":
95
+ print("Using CUDA")
96
+ model = model.to(options.device)
97
+
98
+ trajectory_generator = TrajectoryGenerator(options, place_cells)
99
+
100
+ trainer = Trainer(options, model, trajectory_generator)
101
+
102
+ # Train
103
+ trainer.train(n_epochs=options.n_epochs, n_steps=options.n_steps)
model.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import torch
3
+
4
+
5
+ class RNN(torch.nn.Module):
6
+ def __init__(self, options, place_cells):
7
+ super(RNN, self).__init__()
8
+ self.Ng = options.Ng
9
+ self.Np = options.Np
10
+ self.sequence_length = options.sequence_length
11
+ self.weight_decay = options.weight_decay
12
+ self.place_cells = place_cells
13
+
14
+ # Input weights
15
+ self.encoder = torch.nn.Linear(self.Np, self.Ng, bias=False)
16
+ self.RNN = torch.nn.RNN(
17
+ input_size=2,
18
+ hidden_size=self.Ng,
19
+ nonlinearity=options.activation,
20
+ bias=False,
21
+ )
22
+ # Linear read-out weights
23
+ self.decoder = torch.nn.Linear(self.Ng, self.Np, bias=False)
24
+
25
+ self.softmax = torch.nn.Softmax(dim=-1)
26
+
27
+ def g(self, inputs):
28
+ """
29
+ Compute grid cell activations.
30
+ Args:
31
+ inputs: Batch of 2d velocity inputs with shape [batch_size, sequence_length, 2].
32
+
33
+ Returns:
34
+ g: Batch of grid cell activations with shape [batch_size, sequence_length, Ng].
35
+ """
36
+ v, p0 = inputs
37
+ init_state = self.encoder(p0)[None]
38
+ g, _ = self.RNN(v, init_state)
39
+ return g
40
+
41
+ def predict(self, inputs):
42
+ """
43
+ Predict place cell code.
44
+ Args:
45
+ inputs: Batch of 2d velocity inputs with shape [batch_size, sequence_length, 2].
46
+
47
+ Returns:
48
+ place_preds: Predicted place cell activations with shape
49
+ [batch_size, sequence_length, Np].
50
+ """
51
+ place_preds = self.decoder(self.g(inputs))
52
+
53
+ return place_preds
54
+
55
+ def set_weights(self, weights):
56
+ """
57
+ Load weights from a numpy array (e.g. from the provided example weights).
58
+ Assumes weights are in the order: [encoder, rnn_ih, rnn_hh, decoder]
59
+ and transposed (TF/Keras format).
60
+ """
61
+ with torch.no_grad():
62
+ # Encoder: (Np, Ng) -> (Ng, Np)
63
+ self.encoder.weight.copy_(torch.from_numpy(weights[0].T).float())
64
+
65
+ # RNN input: (2, Ng) -> (Ng, 2)
66
+ self.RNN.weight_ih_l0.copy_(torch.from_numpy(weights[1].T).float())
67
+
68
+ # RNN hidden: (Ng, Ng) -> (Ng, Ng)
69
+ self.RNN.weight_hh_l0.copy_(torch.from_numpy(weights[2].T).float())
70
+
71
+ # Decoder: (Ng, Np) -> (Np, Ng)
72
+ self.decoder.weight.copy_(torch.from_numpy(weights[3].T).float())
73
+
74
+ def compute_loss(self, inputs, pc_outputs, pos):
75
+ """
76
+ Compute avg. loss and decoding error.
77
+ Args:
78
+ inputs: Batch of 2d velocity inputs with shape [batch_size, sequence_length, 2].
79
+ pc_outputs: Ground truth place cell activations with shape
80
+ [batch_size, sequence_length, Np].
81
+ pos: Ground truth 2d position with shape [batch_size, sequence_length, 2].
82
+
83
+ Returns:
84
+ loss: Avg. loss for this training batch.
85
+ err: Avg. decoded position error in cm.
86
+ """
87
+ y: torch.Tensor = pc_outputs
88
+ preds: torch.Tensor = self.predict(inputs)
89
+ loss = torch.nn.functional.cross_entropy(preds.flatten(0, 1), y.flatten(0, 1))
90
+
91
+ # Weight regularization
92
+ loss += self.weight_decay * (self.RNN.weight_hh_l0**2).sum()
93
+
94
+ # Compute decoding error
95
+ pred_pos = self.place_cells.get_nearest_cell_pos(preds)
96
+ err = torch.sqrt(((pos - pred_pos) ** 2).sum(-1)).mean()
97
+
98
+ return loss, err
models/example_pc_centers.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cafa64a4894d5f1c4aaf4befdf036fb43e9c9ed2f2825b79feb02ad1f7199997
3
+ size 4224
models/example_trained_weights.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5277a433615f00894ffe13d55b46d5becda9b362e099c7c73a6191e61bf51ff2
3
+ size 83919352
models/steps_20_batch_200_RNN_4096_relu_rf_012_DoG_True_periodic_False_lr_00001_weight_decay_00001/1.png ADDED

Git LFS Details

  • SHA256: 96c89186cd12b9d006b31ac7ce48ca775d91d3cece5d9655bad10cbfe3e2be3d
  • Pointer size: 131 Bytes
  • Size of remote file: 365 kB
models/steps_20_batch_200_RNN_4096_relu_rf_012_DoG_True_periodic_False_lr_00001_weight_decay_00001/most_recent_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f7050cd3c660329c5c49e1bc935e66cb9eaebeb42caff69930d6ee793c8f6c1
3
+ size 83954107
place_cells.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import numpy as np
3
+ import torch
4
+ import scipy
5
+
6
+
7
+ class PlaceCells(object):
8
+ def __init__(self, options, us=None):
9
+ self.Np = options.Np
10
+ self.sigma = options.place_cell_rf
11
+ self.surround_scale = options.surround_scale
12
+ self.box_width = options.box_width
13
+ self.box_height = options.box_height
14
+ self.is_periodic = options.periodic
15
+ self.DoG = options.DoG
16
+ self.device = options.device
17
+ self.softmax = torch.nn.Softmax(dim=-1)
18
+
19
+ # Randomly tile place cell centers across environment
20
+ usx = np.random.uniform(-self.box_width / 2, self.box_width / 2, (self.Np,))
21
+ usy = np.random.uniform(-self.box_width / 2, self.box_width / 2, (self.Np,))
22
+ self.us = torch.tensor(np.vstack([usx, usy]).T)
23
+ # If using a GPU, put on GPU
24
+ self.us = self.us.to(self.device)
25
+ # self.us = torch.tensor(np.load('models/example_pc_centers.npy')).cuda()
26
+
27
+ def get_activation(self, pos):
28
+ """
29
+ Get place cell activations for a given position.
30
+
31
+ Args:
32
+ pos: 2d position of shape [batch_size, sequence_length, 2].
33
+
34
+ Returns:
35
+ outputs: Place cell activations with shape [batch_size, sequence_length, Np].
36
+ """
37
+ d = torch.abs(pos[:, :, None, :] - self.us[None, None, ...]).float()
38
+
39
+ if self.is_periodic:
40
+ dx = d[:, :, :, 0]
41
+ dy = d[:, :, :, 1]
42
+ dx = torch.minimum(dx, self.box_width - dx)
43
+ dy = torch.minimum(dy, self.box_height - dy)
44
+ d = torch.stack([dx, dy], axis=-1)
45
+
46
+ norm2 = (d**2).sum(-1)
47
+
48
+ # Normalize place cell outputs with prefactor alpha=1/2/np.pi/self.sigma**2,
49
+ # or, simply normalize with softmax, which yields same normalization on
50
+ # average and seems to speed up training.
51
+ outputs = self.softmax(-norm2 / (2 * self.sigma**2))
52
+
53
+ if self.DoG:
54
+ # Again, normalize with prefactor
55
+ # beta=1/2/np.pi/self.sigma**2/self.surround_scale, or use softmax.
56
+ outputs -= self.softmax(-norm2 / (2 * self.surround_scale * self.sigma**2))
57
+
58
+ # Shift and scale outputs so that they lie in [0,1].
59
+ min_output, _ = outputs.min(-1, keepdims=True)
60
+ outputs += torch.abs(min_output)
61
+ outputs /= outputs.sum(-1, keepdims=True)
62
+ return outputs
63
+
64
+ def get_nearest_cell_pos(self, activation, k=3):
65
+ """
66
+ Decode position using centers of k maximally active place cells.
67
+
68
+ Args:
69
+ activation: Place cell activations of shape [batch_size, sequence_length, Np].
70
+ k: Number of maximally active place cells with which to decode position.
71
+
72
+ Returns:
73
+ pred_pos: Predicted 2d position with shape [batch_size, sequence_length, 2].
74
+ """
75
+ _, idxs = torch.topk(activation, k=k)
76
+ pred_pos = self.us[idxs].mean(-2)
77
+ return pred_pos
78
+
79
+ def grid_pc(self, pc_outputs, res=32):
80
+ """Interpolate place cell outputs onto a grid"""
81
+ coordsx = np.linspace(-self.box_width / 2, self.box_width / 2, res)
82
+ coordsy = np.linspace(-self.box_height / 2, self.box_height / 2, res)
83
+ grid_x, grid_y = np.meshgrid(coordsx, coordsy)
84
+ grid = np.stack([grid_x.ravel(), grid_y.ravel()]).T
85
+
86
+ # Convert to numpy
87
+ pc_outputs = pc_outputs.reshape(-1, self.Np)
88
+
89
+ T = pc_outputs.shape[0] # T vs transpose? What is T? (dim's?)
90
+ pc = np.zeros([T, res, res])
91
+ for i in range(len(pc_outputs)):
92
+ gridval = scipy.interpolate.griddata(self.us.cpu(), pc_outputs[i], grid)
93
+ pc[i] = gridval.reshape([res, res])
94
+
95
+ return pc
96
+
97
+ def compute_covariance(self, res=30):
98
+ """Compute spatial covariance matrix of place cell outputs"""
99
+ pos = np.array(
100
+ np.meshgrid(
101
+ np.linspace(-self.box_width / 2, self.box_width / 2, res),
102
+ np.linspace(-self.box_height / 2, self.box_height / 2, res),
103
+ )
104
+ ).T
105
+
106
+ pos = torch.tensor(pos)
107
+
108
+ # Put on GPU if available
109
+ pos = pos.to(self.device)
110
+
111
+ # Maybe specify dimensions here again?
112
+ pc_outputs = self.get_activation(pos).reshape(-1, self.Np).cpu()
113
+
114
+ C = pc_outputs @ pc_outputs.T
115
+ Csquare = C.reshape(res, res, res, res)
116
+
117
+ Cmean = np.zeros([res, res])
118
+ for i in range(res):
119
+ for j in range(res):
120
+ Cmean += np.roll(np.roll(Csquare[i, j], -i, axis=0), -j, axis=1)
121
+
122
+ Cmean = np.roll(np.roll(Cmean, res // 2, axis=0), res // 2, axis=1)
123
+
124
+ return Cmean
pyproject.toml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ authors = [{name = "Mio", email = "violet172309@gmail.com"}]
3
+ dependencies = ["numpy", "tensorflow", "scipy", "matplotlib", "imageio", "opencv-python", "tqdm", "torch"]
4
+ name = "grid-pattern-formation"
5
+ requires-python = ">= 3.11"
6
+ version = "0.1.0"
7
+
8
+ [build-system]
9
+ build-backend = "hatchling.build"
10
+ requires = ["hatchling"]
11
+
12
+ [tool.pixi.workspace]
13
+ channels = ["conda-forge"]
14
+ platforms = ["linux-64"]
15
+
16
+ [tool.pixi.system-requirements]
17
+ cuda = "12.6"
18
+
19
+ [tool.pixi.pypi-dependencies]
20
+ grid_pattern_formation = { path = ".", editable = true }
21
+
22
+ [tool.pixi.tasks]
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ numpy
2
+ tensorflow
3
+ scipy
4
+ matplotlib
5
+ imageio
6
+ opencv-python
7
+ tqdm
8
+ torch
scores.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2018 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Grid score calculations.
17
+ """
18
+
19
+ from __future__ import absolute_import
20
+ from __future__ import division
21
+ from __future__ import print_function
22
+
23
+ import math
24
+ import matplotlib.pyplot as plt
25
+ import numpy as np
26
+ import scipy.signal
27
+ import scipy.ndimage as ndimage
28
+
29
+
30
+ def circle_mask(size, radius, in_val=1.0, out_val=0.0):
31
+ """Calculating the grid scores with different radius."""
32
+ sz = [math.floor(size[0] / 2), math.floor(size[1] / 2)]
33
+ x = np.linspace(-sz[0], sz[1], size[1])
34
+ x = np.expand_dims(x, 0)
35
+ x = x.repeat(size[0], 0)
36
+ y = np.linspace(-sz[0], sz[1], size[1])
37
+ y = np.expand_dims(y, 1)
38
+ y = y.repeat(size[1], 1)
39
+ z = np.sqrt(x**2 + y**2)
40
+ z = np.less_equal(z, radius)
41
+ vfunc = np.vectorize(lambda b: b and in_val or out_val)
42
+ return vfunc(z)
43
+
44
+
45
+ class GridScorer(object):
46
+ """Class for scoring ratemaps given trajectories."""
47
+
48
+ def __init__(self, nbins, coords_range, mask_parameters, min_max=False):
49
+ """Scoring ratemaps given trajectories.
50
+ Args:
51
+ nbins: Number of bins per dimension in the ratemap.
52
+ coords_range: Environment coordinates range.
53
+ mask_parameters: parameters for the masks that analyze the angular
54
+ autocorrelation of the 2D autocorrelation.
55
+ min_max: Correction.
56
+ """
57
+ self._nbins = nbins
58
+ self._min_max = min_max
59
+ self._coords_range = coords_range
60
+ self._corr_angles = [30, 45, 60, 90, 120, 135, 150]
61
+ # Create all masks
62
+ self._masks = [(self._get_ring_mask(mask_min, mask_max), (mask_min,
63
+ mask_max))
64
+ for mask_min, mask_max in mask_parameters]
65
+ # Mask for hiding the parts of the SAC that are never used
66
+ self._plotting_sac_mask = circle_mask(
67
+ [self._nbins * 2 - 1, self._nbins * 2 - 1],
68
+ self._nbins,
69
+ in_val=1.0,
70
+ out_val=np.nan)
71
+
72
+ def calculate_ratemap(self, xs, ys, activations, statistic='mean'):
73
+ return scipy.stats.binned_statistic_2d(
74
+ xs,
75
+ ys,
76
+ activations,
77
+ bins=self._nbins,
78
+ statistic=statistic,
79
+ range=self._coords_range)[0]
80
+
81
+ def _get_ring_mask(self, mask_min, mask_max):
82
+ n_points = [self._nbins * 2 - 1, self._nbins * 2 - 1]
83
+ return (circle_mask(n_points, mask_max * self._nbins) *
84
+ (1 - circle_mask(n_points, mask_min * self._nbins)))
85
+
86
+ def grid_score_60(self, corr):
87
+ if self._min_max:
88
+ return np.minimum(corr[60], corr[120]) - np.maximum(
89
+ corr[30], np.maximum(corr[90], corr[150]))
90
+ else:
91
+ return (corr[60] + corr[120]) / 2 - (corr[30] + corr[90] + corr[150]) / 3
92
+
93
+ def grid_score_90(self, corr):
94
+ return corr[90] - (corr[45] + corr[135]) / 2
95
+
96
+ def calculate_sac(self, seq1):
97
+ """Calculating spatial autocorrelogram."""
98
+ seq2 = seq1
99
+
100
+ def filter2(b, x):
101
+ stencil = np.rot90(b, 2)
102
+ return scipy.signal.convolve2d(x, stencil, mode='full')
103
+
104
+ seq1 = np.nan_to_num(seq1)
105
+ seq2 = np.nan_to_num(seq2)
106
+
107
+ ones_seq1 = np.ones(seq1.shape)
108
+ ones_seq1[np.isnan(seq1)] = 0
109
+ ones_seq2 = np.ones(seq2.shape)
110
+ ones_seq2[np.isnan(seq2)] = 0
111
+
112
+ seq1[np.isnan(seq1)] = 0
113
+ seq2[np.isnan(seq2)] = 0
114
+
115
+ seq1_sq = np.square(seq1)
116
+ seq2_sq = np.square(seq2)
117
+
118
+ seq1_x_seq2 = filter2(seq1, seq2)
119
+ sum_seq1 = filter2(seq1, ones_seq2)
120
+ sum_seq2 = filter2(ones_seq1, seq2)
121
+ sum_seq1_sq = filter2(seq1_sq, ones_seq2)
122
+ sum_seq2_sq = filter2(ones_seq1, seq2_sq)
123
+ n_bins = filter2(ones_seq1, ones_seq2)
124
+ n_bins_sq = np.square(n_bins)
125
+
126
+ std_seq1 = np.power(
127
+ np.subtract(
128
+ np.divide(sum_seq1_sq, n_bins),
129
+ (np.divide(np.square(sum_seq1), n_bins_sq))), 0.5)
130
+ std_seq2 = np.power(
131
+ np.subtract(
132
+ np.divide(sum_seq2_sq, n_bins),
133
+ (np.divide(np.square(sum_seq2), n_bins_sq))), 0.5)
134
+ covar = np.subtract(
135
+ np.divide(seq1_x_seq2, n_bins),
136
+ np.divide(np.multiply(sum_seq1, sum_seq2), n_bins_sq))
137
+ x_coef = np.divide(covar, np.multiply(std_seq1, std_seq2))
138
+ x_coef = np.real(x_coef)
139
+ x_coef = np.nan_to_num(x_coef)
140
+ return x_coef
141
+
142
+ def rotated_sacs(self, sac, angles):
143
+ return [
144
+ scipy.ndimage.interpolation.rotate(sac, angle, reshape=False)
145
+ for angle in angles
146
+ ]
147
+
148
+ def get_grid_scores_for_mask(self, sac, rotated_sacs, mask):
149
+ """Calculate Pearson correlations of area inside mask at corr_angles."""
150
+ masked_sac = sac * mask
151
+ ring_area = np.sum(mask)
152
+ # Calculate dc on the ring area
153
+ masked_sac_mean = np.sum(masked_sac) / ring_area
154
+ # Center the sac values inside the ring
155
+ masked_sac_centered = (masked_sac - masked_sac_mean) * mask
156
+ variance = np.sum(masked_sac_centered**2) / ring_area + 1e-5
157
+ corrs = dict()
158
+ for angle, rotated_sac in zip(self._corr_angles, rotated_sacs):
159
+ masked_rotated_sac = (rotated_sac - masked_sac_mean) * mask
160
+ cross_prod = np.sum(masked_sac_centered * masked_rotated_sac) / ring_area
161
+ corrs[angle] = cross_prod / variance
162
+ return self.grid_score_60(corrs), self.grid_score_90(corrs), variance
163
+
164
+ def get_scores(self, rate_map):
165
+ """Get summary of scrores for grid cells."""
166
+ sac = self.calculate_sac(rate_map)
167
+ rotated_sacs = self.rotated_sacs(sac, self._corr_angles)
168
+
169
+ scores = [
170
+ self.get_grid_scores_for_mask(sac, rotated_sacs, mask)
171
+ for mask, mask_params in self._masks # pylint: disable=unused-variable
172
+ ]
173
+ scores_60, scores_90, variances = map(np.asarray, zip(*scores)) # pylint: disable=unused-variable
174
+ max_60_ind = np.argmax(scores_60)
175
+ max_90_ind = np.argmax(scores_90)
176
+
177
+ return (scores_60[max_60_ind], scores_90[max_90_ind],
178
+ self._masks[max_60_ind][1], self._masks[max_90_ind][1], sac, max_60_ind)
179
+
180
+ def plot_ratemap(self, ratemap, ax=None, title=None, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg
181
+ """Plot ratemaps."""
182
+ if ax is None:
183
+ ax = plt.gca()
184
+ # Plot the ratemap
185
+ ax.imshow(ratemap, interpolation='none', *args, **kwargs)
186
+ # ax.pcolormesh(ratemap, *args, **kwargs)
187
+ ax.axis('off')
188
+ if title is not None:
189
+ ax.set_title(title)
190
+
191
+ def plot_sac(self,
192
+ sac,
193
+ mask_params=None,
194
+ ax=None,
195
+ title=None,
196
+ *args,
197
+ **kwargs): # pylint: disable=keyword-arg-before-vararg
198
+ """Plot spatial autocorrelogram."""
199
+ if ax is None:
200
+ ax = plt.gca()
201
+ # Plot the sac
202
+ useful_sac = sac * self._plotting_sac_mask
203
+ ax.imshow(useful_sac, interpolation='none', *args, **kwargs)
204
+ # ax.pcolormesh(useful_sac, *args, **kwargs)
205
+ # Plot a ring for the adequate mask
206
+ if mask_params is not None:
207
+ center = self._nbins - 1
208
+ ax.add_artist(
209
+ plt.Circle(
210
+ (center, center),
211
+ mask_params[0] * self._nbins,
212
+ # lw=bump_size,
213
+ fill=False,
214
+ edgecolor='k'))
215
+ ax.add_artist(
216
+ plt.Circle(
217
+ (center, center),
218
+ mask_params[1] * self._nbins,
219
+ # lw=bump_size,
220
+ fill=False,
221
+ edgecolor='k'))
222
+ ax.axis('off')
223
+ if title is not None:
224
+ ax.set_title(title)
225
+
226
+
227
+ def border_score(rm, res, box_width):
228
+ # Find connected firing fields
229
+ pix_area = 100**2*box_width**2/res**2
230
+ rm_thresh = rm>(rm.max()*0.3)
231
+ rm_comps, ncomps = ndimage.measurements.label(rm_thresh)
232
+
233
+ # Keep fields with area > 200cm^2
234
+ masks = []
235
+ nfields = 0
236
+ for i in range(1,ncomps+1):
237
+ mask = (rm_comps==i).reshape(res,res)
238
+ if mask.sum()*pix_area > 200:
239
+ masks.append(mask)
240
+ nfields += 1
241
+
242
+ # Max coverage of any one field over any one border
243
+ cm_max = 0
244
+ for mask in masks:
245
+ mask = masks[0]
246
+ n_cov = mask[0].mean()
247
+ s_cov = mask[-1].mean()
248
+ e_cov = mask[:,0].mean()
249
+ w_cov = mask[:,-1].mean()
250
+ cm = np.max([n_cov,s_cov,e_cov,w_cov])
251
+ if cm>cm_max:
252
+ cm_max = cm
253
+
254
+ # Distance to nearest wall
255
+ x,y = np.mgrid[:res,:res] + 1
256
+ x = x.ravel()
257
+ y = y.ravel()
258
+ xmin = np.min(np.vstack([x,res+1-x]),0)
259
+ ymin = np.min(np.vstack([y,res+1-y]),0)
260
+ dweight = np.min(np.vstack([xmin,ymin]),0).reshape(res,res)
261
+ dweight = dweight*box_width/res
262
+
263
+ # Mean firing distance
264
+ dms = []
265
+ for mask in masks:
266
+ field = rm[mask]
267
+ field /= field.sum() # normalize
268
+ dm = (field*dweight[mask]).sum()
269
+ dms.append(dm)
270
+ dm = np.nanmean(dms) / (box_width/2)
271
+ border_score = (cm_max-dm)/(cm_max+dm)
272
+ return border_score, cm_max, dm
src/grid_pattern_formation/__init__.py ADDED
File without changes
trainer.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import torch
3
+ import numpy as np
4
+ import datetime
5
+ from tqdm import tqdm
6
+
7
+ from visualize import save_ratemaps
8
+ import os
9
+
10
+
11
+ class Trainer(object):
12
+ def __init__(self, options, model, trajectory_generator, restore=True):
13
+ self.options = options
14
+ self.model = model
15
+ self.trajectory_generator = trajectory_generator
16
+ lr = self.options.learning_rate
17
+ self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
18
+
19
+ self.loss = []
20
+ self.err = []
21
+
22
+ # Set up checkpoints
23
+ self.ckpt_dir = os.path.join(options.save_dir, options.run_ID)
24
+ ckpt_path = os.path.join(self.ckpt_dir, "most_recent_model.pth")
25
+ pc_path = os.path.join(self.ckpt_dir, "place_cell_centers.npy")
26
+ if restore and os.path.isdir(self.ckpt_dir) and os.path.isfile(ckpt_path):
27
+ self.model.load_state_dict(torch.load(ckpt_path))
28
+ print("Restored trained model from {}".format(ckpt_path))
29
+ # 加载 place cell 位置
30
+ if os.path.isfile(pc_path):
31
+ us = np.load(pc_path)
32
+ self.model.place_cells.us = torch.tensor(us).to(options.device)
33
+ print("Restored place cell centers from {}".format(pc_path))
34
+ else:
35
+ print("Warning: place_cell_centers.npy not found! Model may not work correctly.")
36
+ else:
37
+ if not os.path.isdir(self.ckpt_dir):
38
+ os.makedirs(self.ckpt_dir, exist_ok=True)
39
+ print("Initializing new model from scratch.")
40
+ print("Saving to: {}".format(self.ckpt_dir))
41
+
42
+ def train_step(self, inputs, pc_outputs, pos):
43
+ """
44
+ Train on one batch of trajectories.
45
+
46
+ Args:
47
+ inputs: Batch of 2d velocity inputs with shape [batch_size, sequence_length, 2].
48
+ pc_outputs: Ground truth place cell activations with shape
49
+ [batch_size, sequence_length, Np].
50
+ pos: Ground truth 2d position with shape [batch_size, sequence_length, 2].
51
+
52
+ Returns:
53
+ loss: Avg. loss for this training batch.
54
+ err: Avg. decoded position error in cm.
55
+ """
56
+ self.model.zero_grad()
57
+
58
+ loss, err = self.model.compute_loss(inputs, pc_outputs, pos)
59
+
60
+ loss.backward()
61
+ self.optimizer.step()
62
+
63
+ return loss.item(), err.item()
64
+
65
+ def train(self, n_epochs: int = 1000, n_steps=10, save=True):
66
+ """
67
+ Train model on simulated trajectories.
68
+
69
+ Args:
70
+ n_steps: Number of training steps
71
+ save: If true, save a checkpoint after each epoch.
72
+ """
73
+
74
+ # Construct generator
75
+ gen = self.trajectory_generator.get_generator()
76
+
77
+ for epoch_idx in range(n_epochs):
78
+ tbar = tqdm(range(n_steps), leave=False)
79
+ for step_idx in tbar:
80
+ inputs, pc_outputs, pos = next(gen)
81
+ loss, err = self.train_step(inputs, pc_outputs, pos)
82
+ self.loss.append(loss)
83
+ self.err.append(err)
84
+
85
+ # Log error rate to progress bar
86
+ tbar.set_description('Error = ' + str(int(100*err)) + 'cm')
87
+
88
+ if save and ((epoch_idx + 1) % 10 == 0 or epoch_idx == 0):
89
+ # Save checkpoint
90
+ # ckpt_path = os.path.join(
91
+ # self.ckpt_dir, "epoch_{}.pth".format(epoch_idx)
92
+ # )
93
+ # torch.save(self.model.state_dict(), ckpt_path)
94
+ torch.save(
95
+ self.model.state_dict(),
96
+ os.path.join(self.ckpt_dir, "most_recent_model.pth"),
97
+ )
98
+ # 保存 place cell 位置
99
+ np.save(
100
+ os.path.join(self.ckpt_dir, "place_cell_centers.npy"),
101
+ self.model.place_cells.us.cpu().numpy(),
102
+ )
103
+
104
+ # Save a picture of rate maps
105
+ save_ratemaps(
106
+ self.model,
107
+ self.trajectory_generator,
108
+ self.options,
109
+ step=epoch_idx + 1,
110
+ )
111
+
112
+ print(
113
+ "Epoch: {}/{}. Date: {}. Loss: {}. Err: {}cm".format(
114
+ epoch_idx + 1,
115
+ n_epochs,
116
+ str(datetime.datetime.now())[:-7],
117
+ np.round(loss, 2),
118
+ np.round(100 * err, 2),
119
+ )
120
+ )
trajectory_generator.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import torch
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+
6
+
7
+ class TrajectoryGenerator(object):
8
+ def __init__(self, options, place_cells):
9
+ self.options = options
10
+ self.place_cells = place_cells
11
+
12
+ def plot_trajectory(self, traj, box_width, box_height, idx=0, step=2):
13
+ """
14
+ Visualize one trajectory from traj dict.
15
+
16
+ Args:
17
+ traj: dictionary containing trajectory info
18
+ box_width, box_height: dimensions of the environment
19
+ idx: which trajectory to plot from the batch (default: 0)
20
+ step: plot an arrow every 'step' frames
21
+ """
22
+ # Extract trajectory for one rat
23
+ x = traj["target_x"][idx] # shape (samples,)
24
+ y = traj["target_y"][idx]
25
+ hd = traj["target_hd"][idx] # head directions in radians
26
+
27
+ # Also add starting point
28
+ x0 = traj["init_x"][idx, 0]
29
+ y0 = traj["init_y"][idx, 0]
30
+ hd0 = traj["init_hd"][idx, 0]
31
+
32
+ x = np.concatenate([[x0], x])
33
+ y = np.concatenate([[y0], y])
34
+ hd = np.concatenate([[hd0], hd])
35
+
36
+ # Plot trajectory
37
+ plt.figure(figsize=(6, 6))
38
+ plt.plot(x, y, "-o", markersize=2, label="trajectory")
39
+
40
+ # Add arrows for head direction
41
+ for t in range(0, len(x), step):
42
+ dx = 0.1 * np.cos(hd[t])
43
+ dy = 0.1 * np.sin(hd[t])
44
+ plt.arrow(
45
+ x[t], y[t], dx, dy, head_width=0.05, head_length=0.08, fc="r", ec="r"
46
+ )
47
+
48
+ # Draw box boundaries
49
+ plt.axhline(y=-box_height / 2, color="k")
50
+ plt.axhline(y=box_height / 2, color="k")
51
+ plt.axvline(x=-box_width / 2, color="k")
52
+ plt.axvline(x=box_width / 2, color="k")
53
+
54
+ plt.xlim([-box_width / 2 - 0.2, box_width / 2 + 0.2])
55
+ plt.ylim([-box_height / 2 - 0.2, box_height / 2 + 0.2])
56
+ plt.gca().set_aspect("equal", adjustable="box")
57
+ plt.xlabel("x position (m)")
58
+ plt.ylabel("y position (m)")
59
+ plt.title(f"Trajectory {idx}")
60
+ plt.legend()
61
+ plt.show()
62
+
63
+ def avoid_wall(self, position, hd, box_width, box_height):
64
+ """
65
+ Compute distance and angle to nearest wall
66
+ """
67
+ x = position[:, 0]
68
+ y = position[:, 1]
69
+ dists = [
70
+ box_width / 2 - x,
71
+ box_height / 2 - y,
72
+ box_width / 2 + x,
73
+ box_height / 2 + y,
74
+ ]
75
+ d_wall = np.min(dists, axis=0)
76
+ angles = np.arange(4) * np.pi / 2
77
+ theta = angles[np.argmin(dists, axis=0)]
78
+ hd = np.mod(hd, 2 * np.pi)
79
+ a_wall = hd - theta
80
+ a_wall = np.mod(a_wall + np.pi, 2 * np.pi) - np.pi
81
+
82
+ is_near_wall = (d_wall < self.border_region) * (np.abs(a_wall) < np.pi / 2)
83
+ turn_angle = np.zeros_like(hd)
84
+ turn_angle[is_near_wall] = np.sign(a_wall[is_near_wall]) * (
85
+ np.pi / 2 - np.abs(a_wall[is_near_wall])
86
+ )
87
+
88
+ return is_near_wall, turn_angle
89
+
90
+ def generate_trajectory(self, box_width, box_height, batch_size):
91
+ """Generate a random walk in a rectangular box"""
92
+ samples = self.options.sequence_length
93
+ dt = 0.02 # time step increment (seconds)
94
+ sigma = 5.76 * 2 # stdev rotation velocity (rads/sec)
95
+ b = 0.13 * 2 * np.pi # forward velocity rayleigh dist scale (m/sec)
96
+ mu = 0 # turn angle bias
97
+ self.border_region = 0.03 # meters
98
+
99
+ # Initialize variables
100
+ position = np.zeros([batch_size, samples + 2, 2])
101
+ head_dir = np.zeros([batch_size, samples + 2])
102
+ position[:, 0, 0] = np.random.uniform(-box_width / 2, box_width / 2, batch_size)
103
+ position[:, 0, 1] = np.random.uniform(
104
+ -box_height / 2, box_height / 2, batch_size
105
+ )
106
+ head_dir[:, 0] = np.random.uniform(0, 2 * np.pi, batch_size)
107
+ velocity = np.zeros([batch_size, samples + 2])
108
+
109
+ # Generate sequence of random boosts and turns
110
+ random_turn = np.random.normal(mu, sigma, [batch_size, samples + 1])
111
+ random_vel = np.random.rayleigh(b, [batch_size, samples + 1])
112
+ v = np.abs(np.random.normal(0, b * np.pi / 2, batch_size))
113
+
114
+ for t in range(samples + 1):
115
+ # Update velocity
116
+ v = random_vel[:, t]
117
+ turn_angle = np.zeros(batch_size)
118
+
119
+ if not self.options.periodic:
120
+ # If in border region, turn and slow down
121
+ is_near_wall, turn_angle = self.avoid_wall(
122
+ position[:, t], head_dir[:, t], box_width, box_height
123
+ )
124
+ v[is_near_wall] *= 0.25
125
+
126
+ # Update turn angle
127
+ turn_angle += dt * random_turn[:, t]
128
+
129
+ # Take a step
130
+ velocity[:, t] = v * dt
131
+ update = velocity[:, t, None] * np.stack(
132
+ [np.cos(head_dir[:, t]), np.sin(head_dir[:, t])], axis=-1
133
+ )
134
+ position[:, t + 1] = position[:, t] + update
135
+
136
+ # Rotate head direction
137
+ head_dir[:, t + 1] = head_dir[:, t] + turn_angle
138
+
139
+ # Periodic boundaries
140
+ if self.options.periodic:
141
+ position[:, :, 0] = (
142
+ np.mod(position[:, :, 0] + box_width / 2, box_width) - box_width / 2
143
+ )
144
+ position[:, :, 1] = (
145
+ np.mod(position[:, :, 1] + box_height / 2, box_height) - box_height / 2
146
+ )
147
+
148
+ head_dir = np.mod(head_dir + np.pi, 2 * np.pi) - np.pi # Periodic variable
149
+
150
+ traj = {}
151
+ # Input variables
152
+ traj["init_hd"] = head_dir[:, 0, None]
153
+ traj["init_x"] = position[:, 1, 0, None]
154
+ traj["init_y"] = position[:, 1, 1, None]
155
+
156
+ traj["ego_v"] = velocity[:, 1:-1]
157
+ ang_v = np.diff(head_dir, axis=-1)
158
+ traj["phi_x"], traj["phi_y"] = np.cos(ang_v)[:, :-1], np.sin(ang_v)[:, :-1]
159
+
160
+ # Target variables
161
+ traj["target_hd"] = head_dir[:, 1:-1]
162
+ traj["target_x"] = position[:, 2:, 0]
163
+ traj["target_y"] = position[:, 2:, 1]
164
+
165
+ # for i in range(5):
166
+ # self.plot_trajectory(traj, box_width, box_height, i)
167
+ # raise Exception("dog")
168
+
169
+ return traj
170
+
171
+ def get_generator(self, batch_size=None, box_width=None, box_height=None):
172
+ """
173
+ Returns a generator that yields batches of trajectories
174
+ """
175
+ if not batch_size:
176
+ batch_size = self.options.batch_size
177
+ if not box_width:
178
+ box_width = self.options.box_width
179
+ if not box_height:
180
+ box_height = self.options.box_height
181
+
182
+ while True:
183
+ traj = self.generate_trajectory(box_width, box_height, batch_size)
184
+
185
+ v = np.stack(
186
+ [
187
+ traj["ego_v"] * np.cos(traj["target_hd"]),
188
+ traj["ego_v"] * np.sin(traj["target_hd"]),
189
+ ],
190
+ axis=-1,
191
+ )
192
+ v = torch.tensor(v, dtype=torch.float32).transpose(0, 1)
193
+
194
+ pos = np.stack([traj["target_x"], traj["target_y"]], axis=-1)
195
+ pos = torch.tensor(pos, dtype=torch.float32).transpose(0, 1)
196
+ # Put on GPU if GPU is available
197
+ pos = pos.to(self.options.device)
198
+ place_outputs = self.place_cells.get_activation(pos)
199
+
200
+ init_pos = np.stack([traj["init_x"], traj["init_y"]], axis=-1)
201
+ init_pos = torch.tensor(init_pos, dtype=torch.float32)
202
+ init_pos = init_pos.to(self.options.device)
203
+ init_actv = self.place_cells.get_activation(init_pos).squeeze()
204
+
205
+ v = v.to(self.options.device)
206
+ inputs = (v, init_actv)
207
+
208
+ yield (inputs, place_outputs, pos)
209
+
210
+ def get_test_batch(self, batch_size=None, box_width=None, box_height=None):
211
+ """For testing performance, returns a batch of smample trajectories"""
212
+ if not batch_size:
213
+ batch_size = self.options.batch_size
214
+ if not box_width:
215
+ box_width = self.options.box_width
216
+ if not box_height:
217
+ box_height = self.options.box_height
218
+
219
+ traj = self.generate_trajectory(box_width, box_height, batch_size)
220
+
221
+ v = np.stack(
222
+ [
223
+ traj["ego_v"] * np.cos(traj["target_hd"]),
224
+ traj["ego_v"] * np.sin(traj["target_hd"]),
225
+ ],
226
+ axis=-1,
227
+ )
228
+ v = torch.tensor(v, dtype=torch.float32).transpose(0, 1)
229
+
230
+ pos = np.stack([traj["target_x"], traj["target_y"]], axis=-1)
231
+ pos = torch.tensor(pos, dtype=torch.float32).transpose(0, 1)
232
+ pos = pos.to(self.options.device)
233
+ place_outputs = self.place_cells.get_activation(pos)
234
+
235
+ init_pos = np.stack([traj["init_x"], traj["init_y"]], axis=-1)
236
+ init_pos = torch.tensor(init_pos, dtype=torch.float32)
237
+ init_pos = init_pos.to(self.options.device)
238
+ init_actv = self.place_cells.get_activation(init_pos).squeeze()
239
+
240
+ v = v.to(self.options.device)
241
+ inputs = (v, init_actv)
242
+
243
+ return (inputs, pos, place_outputs)
utils.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def generate_run_ID(options):
5
+ '''
6
+ Create a unique run ID from the most relevant
7
+ parameters. Remaining parameters can be found in
8
+ params.npy file.
9
+ '''
10
+ params = [
11
+ 'steps', str(options.sequence_length),
12
+ 'batch', str(options.batch_size),
13
+ options.RNN_type,
14
+ str(options.Ng),
15
+ options.activation,
16
+ 'rf', str(options.place_cell_rf),
17
+ 'DoG', str(options.DoG),
18
+ 'periodic', str(options.periodic),
19
+ 'lr', str(options.learning_rate),
20
+ 'weight_decay', str(options.weight_decay),
21
+ ]
22
+ separator = '_'
23
+ run_ID = separator.join(params)
24
+ run_ID = run_ID.replace('.', '')
25
+
26
+ return run_ID
27
+
28
+
29
+ def get_2d_sort(x1,x2):
30
+ """
31
+ Reshapes x1 and x2 into square arrays, and then sorts
32
+ them such that x1 increases downward and x2 increases
33
+ rightward. Returns the order.
34
+ """
35
+ n = int(np.round(np.sqrt(len(x1))))
36
+ total_order = x1.argsort()
37
+ total_order = total_order.reshape(n,n)
38
+ for i in range(n):
39
+ row_order = x2[total_order.ravel()].reshape(n,n)[i].argsort()
40
+ total_order[i] = total_order[i,row_order]
41
+ total_order = total_order.ravel()
42
+ return total_order
43
+
44
+
45
+ def dft(N,real=False,scale='sqrtn'):
46
+ if not real:
47
+ return scipy.linalg.dft(N,scale)
48
+ else:
49
+ cosines = np.cos(2*np.pi*np.arange(N//2+1)[None,:]/N*np.arange(N)[:,None])
50
+ sines = np.sin(2*np.pi*np.arange(1,(N-1)//2+1)[None,:]/N*np.arange(N)[:,None])
51
+ if N%2==0:
52
+ cosines[:,-1] /= np.sqrt(2)
53
+ F = np.concatenate((cosines,sines[:,::-1]),1)
54
+ F[:,0] /= np.sqrt(N)
55
+ F[:,1:] /= np.sqrt(N/2)
56
+ return F
57
+
58
+
59
+ def skaggs_power(Jsort):
60
+ F = dft(int(np.sqrt(N)), real=True)
61
+ F2d = F[:,None,:,None]*F[None,:,None,:]
62
+
63
+ F2d_unroll = np.reshape(F2d, (N, N))
64
+
65
+ F2d_inv = F2d_unroll.conj().T
66
+ Jtilde = F2d_inv.dot(Jsort).dot(F2d_unroll)
67
+
68
+ return (Jtilde[1,1]**2 + Jtilde[-1,-1]**2) / (Jtilde**2).sum()
69
+
70
+
71
+ def skaggs_power_2(Jsort):
72
+ J_square = np.reshape(Jsort, (n,n,n,n))
73
+ Jmean = np.zeros([n,n])
74
+ for i in range(n):
75
+ for j in range(n):
76
+ Jmean += np.roll(np.roll(J_square[i,j], -i, axis=0), -j, axis=1)
77
+
78
+ # Jmean[0,0] = np.max(Jmean[1:,1:])
79
+ Jmean = np.roll(np.roll(Jmean, n//2, axis=0), n//2, axis=1)
80
+ Jtilde = np.real(np.fft.fft2(Jmean))
81
+
82
+ Jtilde[0,0] = 0
83
+ sk_power = Jtilde[1,1]**2 + Jtilde[0,1]**2 + Jtilde[1,0]**2
84
+ sk_power += Jtilde[-1,-1]**2 + Jtilde[0,-1]**2 + Jtilde[-1,0]**2
85
+ sk_power /= (Jtilde**2).sum()
86
+
87
+ return sk_power
88
+
89
+
90
+ def calc_err():
91
+ inputs, _, pos = next(gen)
92
+ pred = model(inputs)
93
+ pred_pos = place_cells.get_nearest_cell_pos(pred)
94
+ return tf.reduce_mean(tf.sqrt(tf.reduce_sum((pos - pred_pos)**2, axis=-1)))
95
+
96
+ from visualize import compute_ratemaps, plot_ratemaps
97
+
98
+
99
+ def compute_variance(res, n_avg):
100
+
101
+ activations, rate_map, g, pos = compute_ratemaps(model, data_manager, options, res=res, n_avg=n_avg)
102
+
103
+ counts = np.zeros([res,res])
104
+ variance = np.zeros([res,res])
105
+
106
+ x_all = (pos[:,0] + options['box_width']/2) / options['box_width'] * res
107
+ y_all = (pos[:,1] + options['box_height']/2) / options['box_height'] * res
108
+ for i in tqdm(range(len(g))):
109
+ x = int(x_all[i])
110
+ y = int(y_all[i])
111
+ if x >=0 and x < res and y >=0 and y < res:
112
+ counts[x, y] += 1
113
+ variance[x, y] += np.linalg.norm(g[i] - activations[:, x, y]) / np.linalg.norm(g[i]) / np.linalg.norm(activations[:,x,y])
114
+
115
+ for x in range(res):
116
+ for y in range(res):
117
+ if counts[x, y] > 0:
118
+ variance[x, y] /= counts[x, y]
119
+
120
+ return variance
121
+
122
+
123
+ def load_trained_weights(model, trainer, weight_dir):
124
+ ''' Load weights stored as a .npy file (for github)'''
125
+
126
+ # Train for a single step to initialize weights
127
+ # trainer.train(n_epochs=1, n_steps=1, save=False)
128
+
129
+ # Load weights from npy array
130
+ weights = np.load(weight_dir, allow_pickle=True)
131
+ model.set_weights(weights)
132
+ print('Loaded trained weights.')
133
+
visualize.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import numpy as np
3
+ from matplotlib import pyplot as plt
4
+
5
+ import scipy
6
+ import scipy.stats
7
+ from imageio import imsave
8
+ import cv2
9
+
10
+
11
+ def concat_images(images, image_width, spacer_size):
12
+ """ Concat image horizontally with spacer """
13
+ spacer = np.ones([image_width, spacer_size, 4], dtype=np.uint8) * 255
14
+ images_with_spacers = []
15
+
16
+ image_size = len(images)
17
+
18
+ for i in range(image_size):
19
+ images_with_spacers.append(images[i])
20
+ if i != image_size - 1:
21
+ # Add spacer
22
+ images_with_spacers.append(spacer)
23
+ ret = np.hstack(images_with_spacers)
24
+ return ret
25
+
26
+
27
+ def concat_images_in_rows(images, row_size, image_width, spacer_size=4):
28
+ """ Concat images in rows """
29
+ column_size = len(images) // row_size
30
+ spacer_h = np.ones([spacer_size, image_width*column_size + (column_size-1)*spacer_size, 4],
31
+ dtype=np.uint8) * 255
32
+
33
+ row_images_with_spacers = []
34
+
35
+ for row in range(row_size):
36
+ row_images = images[column_size*row:column_size*row+column_size]
37
+ row_concated_images = concat_images(row_images, image_width, spacer_size)
38
+ row_images_with_spacers.append(row_concated_images)
39
+
40
+ if row != row_size-1:
41
+ row_images_with_spacers.append(spacer_h)
42
+
43
+ ret = np.vstack(row_images_with_spacers)
44
+ return ret
45
+
46
+
47
+ def convert_to_colormap(im, cmap):
48
+ im = cmap(im)
49
+ im = np.uint8(im * 255)
50
+ return im
51
+
52
+
53
+ def rgb(im, cmap='jet', smooth=True):
54
+ cmap = plt.cm.get_cmap(cmap)
55
+ np.seterr(invalid='ignore') # ignore divide by zero err
56
+ im = (im - np.min(im)) / (np.max(im) - np.min(im))
57
+ if smooth:
58
+ im = cv2.GaussianBlur(im, (3,3), sigmaX=1, sigmaY=0)
59
+ im = cmap(im)
60
+ im = np.uint8(im * 255)
61
+ return im
62
+
63
+
64
+ def plot_ratemaps(activations, n_plots, cmap='jet', smooth=True, width=16):
65
+ images = [rgb(im, cmap, smooth) for im in activations[:n_plots]]
66
+ rm_fig = concat_images_in_rows(images, n_plots//width, activations.shape[-1])
67
+ return rm_fig
68
+
69
+
70
+ def compute_ratemaps(model, trajectory_generator, options, res=20, n_avg=None, Ng=512, idxs=None, return_raw=False):
71
+ '''Compute spatial firing fields
72
+
73
+ Args:
74
+ model: The RNN model
75
+ trajectory_generator: Generator for test trajectories
76
+ options: Training options
77
+ res: Resolution of the rate map grid
78
+ n_avg: Number of batches to average over
79
+ Ng: Number of grid cells to analyze
80
+ idxs: Indices of specific grid cells to analyze
81
+ return_raw: If True, also return raw activations (g) and positions (pos).
82
+ Warning: This uses significant memory for large batch_size/n_avg.
83
+ If False, returns None for g and pos to save memory.
84
+
85
+ Returns:
86
+ activations: Spatial firing fields [Ng, res, res]
87
+ rate_map: Flattened rate maps [Ng, res*res]
88
+ g: Raw activations (None if return_raw=False)
89
+ pos: Raw positions (None if return_raw=False)
90
+ '''
91
+
92
+ if not n_avg:
93
+ n_avg = 1000 // options.sequence_length
94
+
95
+ if not np.any(idxs):
96
+ idxs = np.arange(Ng)
97
+ idxs = idxs[:Ng]
98
+
99
+ # Only allocate large arrays if return_raw is True
100
+ if return_raw:
101
+ g = np.zeros([n_avg, options.batch_size * options.sequence_length, Ng])
102
+ pos = np.zeros([n_avg, options.batch_size * options.sequence_length, 2])
103
+ else:
104
+ g = None
105
+ pos = None
106
+
107
+ activations = np.zeros([Ng, res, res])
108
+ counts = np.zeros([res, res])
109
+
110
+ for index in range(n_avg):
111
+ inputs, pos_batch, _ = trajectory_generator.get_test_batch()
112
+ g_batch = model.g(inputs).detach().cpu().numpy()
113
+
114
+ pos_batch = np.reshape(pos_batch.cpu(), [-1, 2])
115
+ g_batch = g_batch[:,:,idxs].reshape(-1, Ng)
116
+
117
+ if return_raw:
118
+ g[index] = g_batch
119
+ pos[index] = pos_batch
120
+
121
+ x_batch = (pos_batch[:,0] + options.box_width/2) / (options.box_width) * res
122
+ y_batch = (pos_batch[:,1] + options.box_height/2) / (options.box_height) * res
123
+
124
+ for i in range(options.batch_size*options.sequence_length):
125
+ x = x_batch[i]
126
+ y = y_batch[i]
127
+ if x >=0 and x < res and y >=0 and y < res:
128
+ counts[int(x), int(y)] += 1
129
+ activations[:, int(x), int(y)] += g_batch[i, :]
130
+
131
+ for x in range(res):
132
+ for y in range(res):
133
+ if counts[x, y] > 0:
134
+ activations[:, x, y] /= counts[x, y]
135
+
136
+ if return_raw:
137
+ g = g.reshape([-1, Ng])
138
+ pos = pos.reshape([-1, 2])
139
+
140
+ # # scipy binned_statistic_2d is slightly slower
141
+ # activations = scipy.stats.binned_statistic_2d(pos[:,0], pos[:,1], g.T, bins=res)[0]
142
+ rate_map = activations.reshape(Ng, -1)
143
+
144
+ return activations, rate_map, g, pos
145
+
146
+
147
+ def save_ratemaps(model, trajectory_generator, options, step, res=20, n_avg=None):
148
+ if not n_avg:
149
+ n_avg = 1000 // options.sequence_length
150
+ activations, rate_map, g, pos = compute_ratemaps(model, trajectory_generator,
151
+ options, res=res, n_avg=n_avg)
152
+ rm_fig = plot_ratemaps(activations, n_plots=len(activations))
153
+ imdir = options.save_dir + "/" + options.run_ID
154
+ imsave(imdir + "/" + str(step) + ".png", rm_fig)
155
+
156
+
157
+ def save_autocorr(sess, model, save_name, trajectory_generator, step, flags):
158
+ starts = [0.2] * 10
159
+ ends = np.linspace(0.4, 1.0, num=10)
160
+ coord_range=((-1.1, 1.1), (-1.1, 1.1))
161
+ masks_parameters = zip(starts, ends.tolist())
162
+ latest_epoch_scorer = scores.GridScorer(20, coord_range, masks_parameters)
163
+
164
+ res = dict()
165
+ index_size = 100
166
+ for _ in range(index_size):
167
+ feed_dict = trajectory_generator.feed_dict(flags.box_width, flags.box_height)
168
+ mb_res = sess.run({
169
+ 'pos_xy': model.target_pos,
170
+ 'bottleneck': model.g,
171
+ }, feed_dict=feed_dict)
172
+ res = utils.concat_dict(res, mb_res)
173
+
174
+ filename = save_name + '/autocorrs_' + str(step) + '.pdf'
175
+ imdir = flags.save_dir + '/'
176
+ out = utils.get_scores_and_plot(
177
+ latest_epoch_scorer, res['pos_xy'], res['bottleneck'],
178
+ imdir, filename)