Upload folder using huggingface_hub
Browse files- .gitattributes +11 -35
- .gitignore +123 -0
- LICENSE +201 -0
- README.md +69 -3
- docs/LSTM_hexagons.png +3 -0
- docs/RNNgrids.png +3 -0
- docs/poisson_spiking.gif +3 -0
- main.py +103 -0
- model.py +98 -0
- models/example_pc_centers.npy +3 -0
- models/example_trained_weights.npy +3 -0
- models/steps_20_batch_200_RNN_4096_relu_rf_012_DoG_True_periodic_False_lr_00001_weight_decay_00001/1.png +3 -0
- models/steps_20_batch_200_RNN_4096_relu_rf_012_DoG_True_periodic_False_lr_00001_weight_decay_00001/most_recent_model.pth +3 -0
- place_cells.py +124 -0
- pyproject.toml +22 -0
- requirements.txt +8 -0
- scores.py +272 -0
- src/grid_pattern_formation/__init__.py +0 -0
- trainer.py +120 -0
- trajectory_generator.py +243 -0
- utils.py +133 -0
- visualize.py +178 -0
.gitattributes
CHANGED
|
@@ -1,35 +1,11 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 1 |
+
# Auto detect text files and perform LF normalization
|
| 2 |
+
* text=auto
|
| 3 |
+
# SCM syntax highlighting & preventing 3-way merges
|
| 4 |
+
pixi.lock merge=binary linguist-language=YAML linguist-generated=true
|
| 5 |
+
docs/LSTM_hexagons.png filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
docs/RNNgrids.png filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
docs/poisson_spiking.gif filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
models/example_pc_centers.npy filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
models/example_trained_weights.npy filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
models/steps_20_batch_200_RNN_4096_relu_rf_012_DoG_True_periodic_False_lr_00001_weight_decay_00001/1.png filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
models/steps_20_batch_200_RNN_4096_relu_rf_012_DoG_True_periodic_False_lr_00001_weight_decay_00001/most_recent_model.pth filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
images/
|
| 7 |
+
movies/
|
| 8 |
+
data/
|
| 9 |
+
fixed_point_finder_data/
|
| 10 |
+
ben_decimate_1000_step/
|
| 11 |
+
ben_true_100_step/
|
| 12 |
+
ben_100*/
|
| 13 |
+
data/
|
| 14 |
+
nohup.out
|
| 15 |
+
|
| 16 |
+
# C extensions
|
| 17 |
+
*.so
|
| 18 |
+
|
| 19 |
+
# Distribution / packaging
|
| 20 |
+
.Python
|
| 21 |
+
build/
|
| 22 |
+
develop-eggs/
|
| 23 |
+
dist/
|
| 24 |
+
downloads/
|
| 25 |
+
eggs/
|
| 26 |
+
.eggs/
|
| 27 |
+
lib/
|
| 28 |
+
lib64/
|
| 29 |
+
parts/
|
| 30 |
+
sdist/
|
| 31 |
+
var/
|
| 32 |
+
wheels/
|
| 33 |
+
*.egg-info/
|
| 34 |
+
.installed.cfg
|
| 35 |
+
*.egg
|
| 36 |
+
MANIFEST
|
| 37 |
+
|
| 38 |
+
# PyInstaller
|
| 39 |
+
# Usually these files are written by a python script from a template
|
| 40 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 41 |
+
*.manifest
|
| 42 |
+
*.spec
|
| 43 |
+
|
| 44 |
+
# Installer logs
|
| 45 |
+
pip-log.txt
|
| 46 |
+
pip-delete-this-directory.txt
|
| 47 |
+
|
| 48 |
+
# Unit test / coverage reports
|
| 49 |
+
htmlcov/
|
| 50 |
+
.tox/
|
| 51 |
+
.coverage
|
| 52 |
+
.coverage.*
|
| 53 |
+
.cache
|
| 54 |
+
nosetests.xml
|
| 55 |
+
coverage.xml
|
| 56 |
+
*.cover
|
| 57 |
+
.hypothesis/
|
| 58 |
+
.pytest_cache/
|
| 59 |
+
|
| 60 |
+
# Translations
|
| 61 |
+
*.mo
|
| 62 |
+
*.pot
|
| 63 |
+
|
| 64 |
+
# Django stuff:
|
| 65 |
+
*.log
|
| 66 |
+
local_settings.py
|
| 67 |
+
db.sqlite3
|
| 68 |
+
|
| 69 |
+
# Flask stuff:
|
| 70 |
+
instance/
|
| 71 |
+
.webassets-cache
|
| 72 |
+
|
| 73 |
+
# Scrapy stuff:
|
| 74 |
+
.scrapy
|
| 75 |
+
|
| 76 |
+
# Sphinx documentation
|
| 77 |
+
docs/_build/
|
| 78 |
+
|
| 79 |
+
# PyBuilder
|
| 80 |
+
target/
|
| 81 |
+
|
| 82 |
+
# Jupyter Notebook
|
| 83 |
+
.ipynb_checkpoints
|
| 84 |
+
|
| 85 |
+
# pyenv
|
| 86 |
+
.python-version
|
| 87 |
+
|
| 88 |
+
# celery beat schedule file
|
| 89 |
+
celerybeat-schedule
|
| 90 |
+
|
| 91 |
+
# SageMath parsed files
|
| 92 |
+
*.sage.py
|
| 93 |
+
|
| 94 |
+
# Environments
|
| 95 |
+
.env
|
| 96 |
+
.venv
|
| 97 |
+
env/
|
| 98 |
+
venv/
|
| 99 |
+
ENV/
|
| 100 |
+
env.bak/
|
| 101 |
+
venv.bak/
|
| 102 |
+
|
| 103 |
+
# Spyder project settings
|
| 104 |
+
.spyderproject
|
| 105 |
+
.spyproject
|
| 106 |
+
|
| 107 |
+
# Rope project settings
|
| 108 |
+
.ropeproject
|
| 109 |
+
|
| 110 |
+
# mkdocs documentation
|
| 111 |
+
/site
|
| 112 |
+
|
| 113 |
+
# mypy
|
| 114 |
+
.mypy_cache/
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# by user
|
| 118 |
+
saved
|
| 119 |
+
*.ipynb
|
| 120 |
+
*.zip
|
| 121 |
+
# pixi environments
|
| 122 |
+
.pixi/*
|
| 123 |
+
!.pixi/config.toml
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
README.md
CHANGED
|
@@ -1,3 +1,69 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[](https://zenodo.org/badge/latestdoi/217773694)
|
| 2 |
+
|
| 3 |
+
# Grid cells in RNNs trained to path integrate
|
| 4 |
+
|
| 5 |
+
Code to reproduce the trained RNN in [**a unified theory for the origin of grid cells through the lens of pattern formation (NeurIPS '19)**](https://papers.nips.cc/paper/9191-a-unified-theory-for-the-origin-of-grid-cells-through-the-lens-of-pattern-formation) and additional analysis described in this [**preprint**](https://www.biorxiv.org/content/10.1101/2020.12.29.424583v1).
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
Quick start:
|
| 9 |
+
|
| 10 |
+
<img src="./docs/poisson_spiking.gif" width="300" align="right">
|
| 11 |
+
|
| 12 |
+
* [**inspect_model.ipynb**](inspect_model.ipynb):
|
| 13 |
+
Train a model and visualize its hidden unit ratemaps.
|
| 14 |
+
|
| 15 |
+
* [**main.py**](main.py):
|
| 16 |
+
or, train a model from the command line.
|
| 17 |
+
|
| 18 |
+
* [**pattern_formation.ipynb**](pattern_formation.ipynb):
|
| 19 |
+
Numerical simulations of pattern-forming dynamics.
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
Includes:
|
| 23 |
+
|
| 24 |
+
* [**trajectory_generator.py**](trajectory_generator.py):
|
| 25 |
+
Generate simulated rat trajectories in a rectangular environment.
|
| 26 |
+
|
| 27 |
+
* [**place_cells.py**](place_cells.py):
|
| 28 |
+
Tile a set of simulated place cells across the training environment.
|
| 29 |
+
|
| 30 |
+
* [**model.py**](model.py):
|
| 31 |
+
Contains the vanilla RNN model architecture, as well as an LSTM.
|
| 32 |
+
|
| 33 |
+
* [**trainer.py**](model.py):
|
| 34 |
+
Contains model training loop.
|
| 35 |
+
|
| 36 |
+
* [**models/example_trained_weights.npy**](models/example_trained_weights.npy)
|
| 37 |
+
Contains a set of pre-trained weights.
|
| 38 |
+
|
| 39 |
+
## Running
|
| 40 |
+
|
| 41 |
+
We recommend creating a virtual environment:
|
| 42 |
+
|
| 43 |
+
```shell
|
| 44 |
+
$ virtualenv env
|
| 45 |
+
$ source env/bin/activate
|
| 46 |
+
$ pip install --upgrade pip
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
Then, install the dependencies automatically with `pip install -r requirements.txt`
|
| 50 |
+
or manually with:
|
| 51 |
+
|
| 52 |
+
```shell
|
| 53 |
+
$ pip install --upgrade numpy==1.17.2
|
| 54 |
+
$ pip install --upgrade tensorflow==2.0.0rc2
|
| 55 |
+
$ pip install --upgrade scipy==1.4.1
|
| 56 |
+
$ pip install --upgrade matplotlib==3.0.3
|
| 57 |
+
$ pip install --upgrade imageio==2.5.0
|
| 58 |
+
$ pip install --upgrade opencv-python==4.1.1.26
|
| 59 |
+
$ pip install --upgrade tqdm==4.36.0
|
| 60 |
+
$ pip install --upgrade opencv-python==4.1.1.26
|
| 61 |
+
$ pip install --upgrade torch==1.10.0
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
If you want to train your own models, make sure to properly set the default
|
| 65 |
+
save directory in `main.py`!
|
| 66 |
+
|
| 67 |
+
## Result
|
| 68 |
+
|
| 69 |
+

|
docs/LSTM_hexagons.png
ADDED
|
Git LFS Details
|
docs/RNNgrids.png
ADDED
|
Git LFS Details
|
docs/poisson_spiking.gif
ADDED
|
Git LFS Details
|
main.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import tensorflow as tf
|
| 3 |
+
import torch.cuda
|
| 4 |
+
import argparse
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
from utils import generate_run_ID
|
| 8 |
+
from place_cells import PlaceCells
|
| 9 |
+
from trajectory_generator import TrajectoryGenerator
|
| 10 |
+
from model import RNN
|
| 11 |
+
from trainer import Trainer
|
| 12 |
+
|
| 13 |
+
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
|
| 14 |
+
|
| 15 |
+
parser = argparse.ArgumentParser()
|
| 16 |
+
parser.add_argument(
|
| 17 |
+
"--save_dir",
|
| 18 |
+
# default='/mnt/fs2/bsorsch/grid_cells/models/',
|
| 19 |
+
default="models/",
|
| 20 |
+
help="directory to save trained models",
|
| 21 |
+
)
|
| 22 |
+
parser.add_argument(
|
| 23 |
+
"--n_epochs", default=100, help="number of training epochs", type=int
|
| 24 |
+
)
|
| 25 |
+
parser.add_argument("--n_steps", default=1000, help="batches per epoch", type=int)
|
| 26 |
+
parser.add_argument(
|
| 27 |
+
"--batch_size", default=200, help="number of trajectories per batch", type=int
|
| 28 |
+
)
|
| 29 |
+
parser.add_argument(
|
| 30 |
+
"--sequence_length", default=20, help="number of steps in trajectory", type=int
|
| 31 |
+
)
|
| 32 |
+
parser.add_argument(
|
| 33 |
+
"--learning_rate", default=1e-4, help="gradient descent learning rate", type=float
|
| 34 |
+
)
|
| 35 |
+
parser.add_argument("--Np", default=512, help="number of place cells", type=int)
|
| 36 |
+
parser.add_argument("--Ng", default=4096, help="number of grid cells", type=int)
|
| 37 |
+
parser.add_argument(
|
| 38 |
+
"--place_cell_rf",
|
| 39 |
+
default=0.12,
|
| 40 |
+
help="width of place cell center tuning curve (m)",
|
| 41 |
+
type=float,
|
| 42 |
+
)
|
| 43 |
+
parser.add_argument(
|
| 44 |
+
"--surround_scale",
|
| 45 |
+
default=2,
|
| 46 |
+
help="if DoG, ratio of sigma2^2 to sigma1^2",
|
| 47 |
+
type=int,
|
| 48 |
+
)
|
| 49 |
+
parser.add_argument("--RNN_type", default="RNN", help="RNN or LSTM")
|
| 50 |
+
parser.add_argument("--activation", default="relu", help="recurrent nonlinearity")
|
| 51 |
+
parser.add_argument(
|
| 52 |
+
"--weight_decay",
|
| 53 |
+
default=1e-4,
|
| 54 |
+
help="strength of weight decay on recurrent weights",
|
| 55 |
+
type=float,
|
| 56 |
+
)
|
| 57 |
+
parser.add_argument(
|
| 58 |
+
"--DoG", default=True, help="use difference of gaussians tuning curves"
|
| 59 |
+
)
|
| 60 |
+
parser.add_argument(
|
| 61 |
+
"--periodic", default=False, help="trajectories with periodic boundary conditions"
|
| 62 |
+
)
|
| 63 |
+
parser.add_argument(
|
| 64 |
+
"--box_width", default=2.2, help="width of training environment", type=float
|
| 65 |
+
)
|
| 66 |
+
parser.add_argument(
|
| 67 |
+
"--box_height", default=2.2, help="height of training environment", type=float
|
| 68 |
+
)
|
| 69 |
+
parser.add_argument(
|
| 70 |
+
"--device",
|
| 71 |
+
default="cuda" if torch.cuda.is_available() else "cpu",
|
| 72 |
+
help="device to use for training",
|
| 73 |
+
)
|
| 74 |
+
parser.add_argument(
|
| 75 |
+
"--seed", default=None, help="seed number for all numpy random number generator"
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
options = parser.parse_args()
|
| 79 |
+
options.run_ID = generate_run_ID(options)
|
| 80 |
+
|
| 81 |
+
print(f"Using device: {options.device}")
|
| 82 |
+
|
| 83 |
+
if options.seed:
|
| 84 |
+
np.random.seed(int(options.seed))
|
| 85 |
+
|
| 86 |
+
place_cells = PlaceCells(options)
|
| 87 |
+
if options.RNN_type == "RNN":
|
| 88 |
+
model = RNN(options, place_cells)
|
| 89 |
+
elif options.RNN_type == "LSTM":
|
| 90 |
+
# model = LSTM(options, place_cells)
|
| 91 |
+
raise NotImplementedError
|
| 92 |
+
|
| 93 |
+
# Put model on GPU if using GPU
|
| 94 |
+
if options.device == "cuda":
|
| 95 |
+
print("Using CUDA")
|
| 96 |
+
model = model.to(options.device)
|
| 97 |
+
|
| 98 |
+
trajectory_generator = TrajectoryGenerator(options, place_cells)
|
| 99 |
+
|
| 100 |
+
trainer = Trainer(options, model, trajectory_generator)
|
| 101 |
+
|
| 102 |
+
# Train
|
| 103 |
+
trainer.train(n_epochs=options.n_epochs, n_steps=options.n_steps)
|
model.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class RNN(torch.nn.Module):
|
| 6 |
+
def __init__(self, options, place_cells):
|
| 7 |
+
super(RNN, self).__init__()
|
| 8 |
+
self.Ng = options.Ng
|
| 9 |
+
self.Np = options.Np
|
| 10 |
+
self.sequence_length = options.sequence_length
|
| 11 |
+
self.weight_decay = options.weight_decay
|
| 12 |
+
self.place_cells = place_cells
|
| 13 |
+
|
| 14 |
+
# Input weights
|
| 15 |
+
self.encoder = torch.nn.Linear(self.Np, self.Ng, bias=False)
|
| 16 |
+
self.RNN = torch.nn.RNN(
|
| 17 |
+
input_size=2,
|
| 18 |
+
hidden_size=self.Ng,
|
| 19 |
+
nonlinearity=options.activation,
|
| 20 |
+
bias=False,
|
| 21 |
+
)
|
| 22 |
+
# Linear read-out weights
|
| 23 |
+
self.decoder = torch.nn.Linear(self.Ng, self.Np, bias=False)
|
| 24 |
+
|
| 25 |
+
self.softmax = torch.nn.Softmax(dim=-1)
|
| 26 |
+
|
| 27 |
+
def g(self, inputs):
|
| 28 |
+
"""
|
| 29 |
+
Compute grid cell activations.
|
| 30 |
+
Args:
|
| 31 |
+
inputs: Batch of 2d velocity inputs with shape [batch_size, sequence_length, 2].
|
| 32 |
+
|
| 33 |
+
Returns:
|
| 34 |
+
g: Batch of grid cell activations with shape [batch_size, sequence_length, Ng].
|
| 35 |
+
"""
|
| 36 |
+
v, p0 = inputs
|
| 37 |
+
init_state = self.encoder(p0)[None]
|
| 38 |
+
g, _ = self.RNN(v, init_state)
|
| 39 |
+
return g
|
| 40 |
+
|
| 41 |
+
def predict(self, inputs):
|
| 42 |
+
"""
|
| 43 |
+
Predict place cell code.
|
| 44 |
+
Args:
|
| 45 |
+
inputs: Batch of 2d velocity inputs with shape [batch_size, sequence_length, 2].
|
| 46 |
+
|
| 47 |
+
Returns:
|
| 48 |
+
place_preds: Predicted place cell activations with shape
|
| 49 |
+
[batch_size, sequence_length, Np].
|
| 50 |
+
"""
|
| 51 |
+
place_preds = self.decoder(self.g(inputs))
|
| 52 |
+
|
| 53 |
+
return place_preds
|
| 54 |
+
|
| 55 |
+
def set_weights(self, weights):
|
| 56 |
+
"""
|
| 57 |
+
Load weights from a numpy array (e.g. from the provided example weights).
|
| 58 |
+
Assumes weights are in the order: [encoder, rnn_ih, rnn_hh, decoder]
|
| 59 |
+
and transposed (TF/Keras format).
|
| 60 |
+
"""
|
| 61 |
+
with torch.no_grad():
|
| 62 |
+
# Encoder: (Np, Ng) -> (Ng, Np)
|
| 63 |
+
self.encoder.weight.copy_(torch.from_numpy(weights[0].T).float())
|
| 64 |
+
|
| 65 |
+
# RNN input: (2, Ng) -> (Ng, 2)
|
| 66 |
+
self.RNN.weight_ih_l0.copy_(torch.from_numpy(weights[1].T).float())
|
| 67 |
+
|
| 68 |
+
# RNN hidden: (Ng, Ng) -> (Ng, Ng)
|
| 69 |
+
self.RNN.weight_hh_l0.copy_(torch.from_numpy(weights[2].T).float())
|
| 70 |
+
|
| 71 |
+
# Decoder: (Ng, Np) -> (Np, Ng)
|
| 72 |
+
self.decoder.weight.copy_(torch.from_numpy(weights[3].T).float())
|
| 73 |
+
|
| 74 |
+
def compute_loss(self, inputs, pc_outputs, pos):
    """
    Compute avg. loss and decoding error.
    Args:
        inputs: Batch of 2d velocity inputs with shape [batch_size, sequence_length, 2].
        pc_outputs: Ground truth place cell activations with shape
            [batch_size, sequence_length, Np].
        pos: Ground truth 2d position with shape [batch_size, sequence_length, 2].

    Returns:
        loss: Avg. loss for this training batch.
        err: Avg. decoded position error, in the same units as ``pos``
            (meters — the trainer multiplies by 100 to report cm).
    """
    y: torch.Tensor = pc_outputs
    preds: torch.Tensor = self.predict(inputs)
    # Soft-target cross entropy: `y` is a probability vector over place
    # cells and `preds` are logits. Time and batch axes are merged so each
    # (batch, step) pair is one sample. NOTE(review): probability targets
    # for cross_entropy require torch >= 1.10 — confirm pinned version.
    loss = torch.nn.functional.cross_entropy(preds.flatten(0, 1), y.flatten(0, 1))

    # Weight regularization (L2 on the recurrent weights only).
    loss += self.weight_decay * (self.RNN.weight_hh_l0**2).sum()

    # Compute decoding error: decode position from the predicted place-cell
    # code and take the mean Euclidean distance to the true position.
    pred_pos = self.place_cells.get_nearest_cell_pos(preds)
    err = torch.sqrt(((pos - pred_pos) ** 2).sum(-1)).mean()

    return loss, err
|
models/example_pc_centers.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cafa64a4894d5f1c4aaf4befdf036fb43e9c9ed2f2825b79feb02ad1f7199997
|
| 3 |
+
size 4224
|
models/example_trained_weights.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5277a433615f00894ffe13d55b46d5becda9b362e099c7c73a6191e61bf51ff2
|
| 3 |
+
size 83919352
|
models/steps_20_batch_200_RNN_4096_relu_rf_012_DoG_True_periodic_False_lr_00001_weight_decay_00001/1.png
ADDED
|
Git LFS Details
|
models/steps_20_batch_200_RNN_4096_relu_rf_012_DoG_True_periodic_False_lr_00001_weight_decay_00001/most_recent_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8f7050cd3c660329c5c49e1bc935e66cb9eaebeb42caff69930d6ee793c8f6c1
|
| 3 |
+
size 83954107
|
place_cells.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import scipy
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class PlaceCells(object):
    """Population of simulated place cells tiling a 2d box environment.

    Each cell's tuning curve is a Gaussian (or difference-of-Gaussians when
    ``options.DoG`` is set) centered at a random location in the box.
    Activations are normalized so each population vector sums to 1.
    """

    def __init__(self, options, us=None):
        """
        Args:
            options: Config namespace providing Np, place_cell_rf,
                surround_scale, box_width, box_height, periodic, DoG, device.
            us: Optional pre-computed [Np, 2] array of place cell centers
                (e.g. loaded from disk). If None, centers are sampled
                uniformly at random.
        """
        self.Np = options.Np                          # number of place cells
        self.sigma = options.place_cell_rf            # tuning width (meters)
        self.surround_scale = options.surround_scale  # surround/center variance ratio (DoG)
        self.box_width = options.box_width
        self.box_height = options.box_height
        self.is_periodic = options.periodic
        self.DoG = options.DoG
        self.device = options.device
        self.softmax = torch.nn.Softmax(dim=-1)

        if us is None:
            # Randomly tile place cell centers across the environment.
            # BUGFIX: the y-coordinates were previously sampled from the
            # *width* range; use box_height so non-square boxes are covered.
            usx = np.random.uniform(-self.box_width / 2, self.box_width / 2, (self.Np,))
            usy = np.random.uniform(-self.box_height / 2, self.box_height / 2, (self.Np,))
            centers = np.vstack([usx, usy]).T
        else:
            # Previously this argument was silently ignored; honor it so
            # saved centers can be restored through the constructor.
            centers = np.asarray(us)
        # If using a GPU, put on GPU
        self.us = torch.tensor(centers).to(self.device)

    def get_activation(self, pos):
        """
        Get place cell activations for a given position.

        Args:
            pos: 2d position of shape [batch_size, sequence_length, 2].

        Returns:
            outputs: Place cell activations with shape
                [batch_size, sequence_length, Np]; each population vector
                sums to 1 and is non-negative.
        """
        d = torch.abs(pos[:, :, None, :] - self.us[None, None, ...]).float()

        if self.is_periodic:
            # Wrap distances around the torus.
            dx = d[:, :, :, 0]
            dy = d[:, :, :, 1]
            dx = torch.minimum(dx, self.box_width - dx)
            dy = torch.minimum(dy, self.box_height - dy)
            d = torch.stack([dx, dy], dim=-1)

        norm2 = (d**2).sum(-1)

        # Normalize place cell outputs with prefactor alpha=1/2/np.pi/self.sigma**2,
        # or, simply normalize with softmax, which yields same normalization on
        # average and seems to speed up training.
        outputs = self.softmax(-norm2 / (2 * self.sigma**2))

        if self.DoG:
            # Again, normalize with prefactor
            # beta=1/2/np.pi/self.sigma**2/self.surround_scale, or use softmax.
            outputs -= self.softmax(-norm2 / (2 * self.surround_scale * self.sigma**2))

            # Shift and scale outputs so that they lie in [0,1] and sum to 1.
            min_output, _ = outputs.min(-1, keepdim=True)
            outputs += torch.abs(min_output)
            outputs /= outputs.sum(-1, keepdim=True)
        return outputs

    def get_nearest_cell_pos(self, activation, k=3):
        """
        Decode position using centers of k maximally active place cells.

        Args:
            activation: Place cell activations of shape [batch_size, sequence_length, Np].
            k: Number of maximally active place cells with which to decode position.

        Returns:
            pred_pos: Predicted 2d position with shape [batch_size, sequence_length, 2].
        """
        _, idxs = torch.topk(activation, k=k)
        pred_pos = self.us[idxs].mean(-2)
        return pred_pos

    def grid_pc(self, pc_outputs, res=32):
        """Interpolate place cell outputs onto a (res, res) grid."""
        # BUGFIX: the module only does `import scipy` at the top, which does
        # not guarantee the interpolate submodule is loaded.
        import scipy.interpolate

        coordsx = np.linspace(-self.box_width / 2, self.box_width / 2, res)
        coordsy = np.linspace(-self.box_height / 2, self.box_height / 2, res)
        grid_x, grid_y = np.meshgrid(coordsx, coordsy)
        grid = np.stack([grid_x.ravel(), grid_y.ravel()]).T

        # Flatten batch/time so each row is one Np-dim population vector.
        pc_outputs = pc_outputs.reshape(-1, self.Np)

        T = pc_outputs.shape[0]  # number of population vectors to interpolate
        pc = np.zeros([T, res, res])
        for i in range(len(pc_outputs)):
            gridval = scipy.interpolate.griddata(self.us.cpu(), pc_outputs[i], grid)
            pc[i] = gridval.reshape([res, res])

        return pc

    def compute_covariance(self, res=30):
        """Compute spatial covariance matrix of place cell outputs."""
        pos = np.array(
            np.meshgrid(
                np.linspace(-self.box_width / 2, self.box_width / 2, res),
                np.linspace(-self.box_height / 2, self.box_height / 2, res),
            )
        ).T

        pos = torch.tensor(pos)

        # Put on GPU if available
        pos = pos.to(self.device)

        pc_outputs = self.get_activation(pos).reshape(-1, self.Np).cpu()

        C = pc_outputs @ pc_outputs.T
        Csquare = C.reshape(res, res, res, res)

        # NOTE(review): np.roll on a torch CPU tensor converts to ndarray via
        # __array__; verified on CPU tensors only.
        Cmean = np.zeros([res, res])
        for i in range(res):
            for j in range(res):
                Cmean += np.roll(np.roll(Csquare[i, j], -i, axis=0), -j, axis=1)

        # Re-center so zero displacement sits at the middle of the map.
        Cmean = np.roll(np.roll(Cmean, res // 2, axis=0), res // 2, axis=1)

        return Cmean
|
pyproject.toml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
authors = [{name = "Mio", email = "violet172309@gmail.com"}]
|
| 3 |
+
dependencies = ["numpy", "tensorflow", "scipy", "matplotlib", "imageio", "opencv-python", "tqdm", "torch"]
|
| 4 |
+
name = "grid-pattern-formation"
|
| 5 |
+
requires-python = ">= 3.11"
|
| 6 |
+
version = "0.1.0"
|
| 7 |
+
|
| 8 |
+
[build-system]
|
| 9 |
+
build-backend = "hatchling.build"
|
| 10 |
+
requires = ["hatchling"]
|
| 11 |
+
|
| 12 |
+
[tool.pixi.workspace]
|
| 13 |
+
channels = ["conda-forge"]
|
| 14 |
+
platforms = ["linux-64"]
|
| 15 |
+
|
| 16 |
+
[tool.pixi.system-requirements]
|
| 17 |
+
cuda = "12.6"
|
| 18 |
+
|
| 19 |
+
[tool.pixi.pypi-dependencies]
|
| 20 |
+
grid_pattern_formation = { path = ".", editable = true }
|
| 21 |
+
|
| 22 |
+
[tool.pixi.tasks]
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy
|
| 2 |
+
tensorflow
|
| 3 |
+
scipy
|
| 4 |
+
matplotlib
|
| 5 |
+
imageio
|
| 6 |
+
opencv-python
|
| 7 |
+
tqdm
|
| 8 |
+
torch
|
scores.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2018 Google LLC
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Grid score calculations.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import absolute_import
|
| 20 |
+
from __future__ import division
|
| 21 |
+
from __future__ import print_function
|
| 22 |
+
|
| 23 |
+
import math
|
| 24 |
+
import matplotlib.pyplot as plt
|
| 25 |
+
import numpy as np
|
| 26 |
+
import scipy.signal
|
| 27 |
+
import scipy.ndimage as ndimage
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def circle_mask(size, radius, in_val=1.0, out_val=0.0):
    """Build a (size[0], size[1]) array equal to `in_val` inside a centered
    circle of the given radius (in bins) and `out_val` outside it."""
    sz = [math.floor(size[0] / 2), math.floor(size[1] / 2)]
    x = np.linspace(-sz[0], sz[1], size[1])
    x = np.expand_dims(x, 0)
    x = x.repeat(size[0], 0)
    y = np.linspace(-sz[0], sz[1], size[1])
    y = np.expand_dims(y, 1)
    y = y.repeat(size[1], 1)
    z = np.sqrt(x**2 + y**2)
    z = np.less_equal(z, radius)
    vfunc = np.vectorize(lambda b: b and in_val or out_val)
    return vfunc(z)


class GridScorer(object):
    """Class for scoring ratemaps given trajectories."""

    def __init__(self, nbins, coords_range, mask_parameters, min_max=False):
        """Scoring ratemaps given trajectories.
        Args:
          nbins: Number of bins per dimension in the ratemap.
          coords_range: Environment coordinates range.
          mask_parameters: parameters for the masks that analyze the angular
            autocorrelation of the 2D autocorrelation.
          min_max: Correction.
        """
        self._nbins = nbins
        self._min_max = min_max
        self._coords_range = coords_range
        # Rotation angles (degrees) at which the SAC is correlated with itself.
        self._corr_angles = [30, 45, 60, 90, 120, 135, 150]
        # Create all annular masks, paired with their (min, max) radii.
        self._masks = [(self._get_ring_mask(mask_min, mask_max), (mask_min,
                                                                 mask_max))
                       for mask_min, mask_max in mask_parameters]
        # Mask for hiding the parts of the SAC that are never used
        self._plotting_sac_mask = circle_mask(
            [self._nbins * 2 - 1, self._nbins * 2 - 1],
            self._nbins,
            in_val=1.0,
            out_val=np.nan)

    def calculate_ratemap(self, xs, ys, activations, statistic='mean'):
        """Bin `activations` by (x, y) position into an (nbins, nbins) ratemap."""
        # BUGFIX: the module imports scipy.signal / scipy.ndimage only, so
        # `scipy.stats` is not guaranteed to be an attribute of `scipy`.
        from scipy.stats import binned_statistic_2d
        return binned_statistic_2d(
            xs,
            ys,
            activations,
            bins=self._nbins,
            statistic=statistic,
            range=self._coords_range)[0]

    def _get_ring_mask(self, mask_min, mask_max):
        """Annulus of 1s between radii mask_min and mask_max (fractions of nbins)."""
        n_points = [self._nbins * 2 - 1, self._nbins * 2 - 1]
        return (circle_mask(n_points, mask_max * self._nbins) *
                (1 - circle_mask(n_points, mask_min * self._nbins)))

    def grid_score_60(self, corr):
        """Hexagonal (60-degree) grid score from angular correlations."""
        if self._min_max:
            return np.minimum(corr[60], corr[120]) - np.maximum(
                corr[30], np.maximum(corr[90], corr[150]))
        else:
            return (corr[60] + corr[120]) / 2 - (corr[30] + corr[90] + corr[150]) / 3

    def grid_score_90(self, corr):
        """Square (90-degree) grid score from angular correlations."""
        return corr[90] - (corr[45] + corr[135]) / 2

    def calculate_sac(self, seq1):
        """Calculating spatial autocorrelogram."""
        seq2 = seq1

        def filter2(b, x):
            stencil = np.rot90(b, 2)
            return scipy.signal.convolve2d(x, stencil, mode='full')

        seq1 = np.nan_to_num(seq1)
        seq2 = np.nan_to_num(seq2)

        ones_seq1 = np.ones(seq1.shape)
        ones_seq1[np.isnan(seq1)] = 0
        ones_seq2 = np.ones(seq2.shape)
        ones_seq2[np.isnan(seq2)] = 0

        seq1[np.isnan(seq1)] = 0
        seq2[np.isnan(seq2)] = 0

        seq1_sq = np.square(seq1)
        seq2_sq = np.square(seq2)

        # Sliding-window sums needed for a per-overlap Pearson correlation.
        seq1_x_seq2 = filter2(seq1, seq2)
        sum_seq1 = filter2(seq1, ones_seq2)
        sum_seq2 = filter2(ones_seq1, seq2)
        sum_seq1_sq = filter2(seq1_sq, ones_seq2)
        sum_seq2_sq = filter2(ones_seq1, seq2_sq)
        n_bins = filter2(ones_seq1, ones_seq2)
        n_bins_sq = np.square(n_bins)

        std_seq1 = np.power(
            np.subtract(
                np.divide(sum_seq1_sq, n_bins),
                (np.divide(np.square(sum_seq1), n_bins_sq))), 0.5)
        std_seq2 = np.power(
            np.subtract(
                np.divide(sum_seq2_sq, n_bins),
                (np.divide(np.square(sum_seq2), n_bins_sq))), 0.5)
        covar = np.subtract(
            np.divide(seq1_x_seq2, n_bins),
            np.divide(np.multiply(sum_seq1, sum_seq2), n_bins_sq))
        x_coef = np.divide(covar, np.multiply(std_seq1, std_seq2))
        x_coef = np.real(x_coef)
        x_coef = np.nan_to_num(x_coef)
        return x_coef

    def rotated_sacs(self, sac, angles):
        """Rotate the SAC by each angle (degrees), keeping the original shape."""
        # BUGFIX: scipy.ndimage.interpolation was removed in SciPy >= 1.10;
        # the top-level scipy.ndimage.rotate is the supported spelling.
        return [
            ndimage.rotate(sac, angle, reshape=False)
            for angle in angles
        ]

    def get_grid_scores_for_mask(self, sac, rotated_sacs, mask):
        """Calculate Pearson correlations of area inside mask at corr_angles."""
        masked_sac = sac * mask
        ring_area = np.sum(mask)
        # Calculate dc on the ring area
        masked_sac_mean = np.sum(masked_sac) / ring_area
        # Center the sac values inside the ring
        masked_sac_centered = (masked_sac - masked_sac_mean) * mask
        variance = np.sum(masked_sac_centered**2) / ring_area + 1e-5
        corrs = dict()
        for angle, rotated_sac in zip(self._corr_angles, rotated_sacs):
            masked_rotated_sac = (rotated_sac - masked_sac_mean) * mask
            cross_prod = np.sum(masked_sac_centered * masked_rotated_sac) / ring_area
            corrs[angle] = cross_prod / variance
        return self.grid_score_60(corrs), self.grid_score_90(corrs), variance

    def get_scores(self, rate_map):
        """Get summary of scores for grid cells."""
        sac = self.calculate_sac(rate_map)
        rotated_sacs = self.rotated_sacs(sac, self._corr_angles)

        scores = [
            self.get_grid_scores_for_mask(sac, rotated_sacs, mask)
            for mask, mask_params in self._masks  # pylint: disable=unused-variable
        ]
        scores_60, scores_90, variances = map(np.asarray, zip(*scores))  # pylint: disable=unused-variable
        max_60_ind = np.argmax(scores_60)
        max_90_ind = np.argmax(scores_90)

        return (scores_60[max_60_ind], scores_90[max_90_ind],
                self._masks[max_60_ind][1], self._masks[max_90_ind][1], sac, max_60_ind)

    def plot_ratemap(self, ratemap, ax=None, title=None, *args, **kwargs):  # pylint: disable=keyword-arg-before-vararg
        """Plot ratemaps."""
        if ax is None:
            ax = plt.gca()
        # Plot the ratemap
        ax.imshow(ratemap, interpolation='none', *args, **kwargs)
        ax.axis('off')
        if title is not None:
            ax.set_title(title)

    def plot_sac(self,
                 sac,
                 mask_params=None,
                 ax=None,
                 title=None,
                 *args,
                 **kwargs):  # pylint: disable=keyword-arg-before-vararg
        """Plot spatial autocorrelogram."""
        if ax is None:
            ax = plt.gca()
        # Plot the sac, hiding the corners that carry no information.
        useful_sac = sac * self._plotting_sac_mask
        ax.imshow(useful_sac, interpolation='none', *args, **kwargs)
        # Plot a ring for the adequate mask
        if mask_params is not None:
            center = self._nbins - 1
            ax.add_artist(
                plt.Circle(
                    (center, center),
                    mask_params[0] * self._nbins,
                    fill=False,
                    edgecolor='k'))
            ax.add_artist(
                plt.Circle(
                    (center, center),
                    mask_params[1] * self._nbins,
                    fill=False,
                    edgecolor='k'))
        ax.axis('off')
        if title is not None:
            ax.set_title(title)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def border_score(rm, res, box_width):
    """Compute the border score of a ratemap (Solstad et al. 2008 style).

    Args:
        rm: Ratemap array of shape (res, res).
        res: Number of bins per side of the ratemap.
        box_width: Environment width in meters.

    Returns:
        (border_score, cm_max, dm): score in [-1, 1]; the maximum fraction of
        any single wall covered by one firing field; and the rate-weighted
        mean distance to the nearest wall, normalized by half the box width.
        If no field exceeds the area threshold, dm (and the score) is NaN.
    """
    # Find connected firing fields above 30% of the peak rate.
    pix_area = 100**2 * box_width**2 / res**2  # bin area in cm^2
    rm_thresh = rm > (rm.max() * 0.3)
    # BUGFIX: ndimage.measurements was removed in modern SciPy; use the
    # top-level ndimage.label.
    rm_comps, ncomps = ndimage.label(rm_thresh)

    # Keep fields with area > 200 cm^2
    masks = []
    for i in range(1, ncomps + 1):
        mask = (rm_comps == i).reshape(res, res)
        if mask.sum() * pix_area > 200:
            masks.append(mask)

    # Max coverage of any one field over any one border.
    # BUGFIX: the original loop overwrote `mask` with masks[0] on every
    # iteration, so only the first field was ever scored.
    cm_max = 0
    for mask in masks:
        n_cov = mask[0].mean()
        s_cov = mask[-1].mean()
        e_cov = mask[:, 0].mean()
        w_cov = mask[:, -1].mean()
        cm = np.max([n_cov, s_cov, e_cov, w_cov])
        if cm > cm_max:
            cm_max = cm

    # Distance from each bin center to its nearest wall, in meters.
    x, y = np.mgrid[:res, :res] + 1
    x = x.ravel()
    y = y.ravel()
    xmin = np.min(np.vstack([x, res + 1 - x]), 0)
    ymin = np.min(np.vstack([y, res + 1 - y]), 0)
    dweight = np.min(np.vstack([xmin, ymin]), 0).reshape(res, res)
    dweight = dweight * box_width / res

    # Rate-weighted mean firing distance of each field to its nearest wall.
    dms = []
    for mask in masks:
        field = rm[mask]
        field = field / field.sum()  # normalize (boolean indexing copies, rm untouched)
        dm = (field * dweight[mask]).sum()
        dms.append(dm)
    dm = np.nanmean(dms) / (box_width / 2)
    border_score = (cm_max - dm) / (cm_max + dm)
    return border_score, cm_max, dm
|
src/grid_pattern_formation/__init__.py
ADDED
|
File without changes
|
trainer.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import datetime
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
from visualize import save_ratemaps
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Trainer(object):
    """Training-loop wrapper: optimizer setup, checkpoint restore/save, and
    the epoch/step loop over generated trajectories."""

    def __init__(self, options, model, trajectory_generator, restore=True):
        # options: config namespace (learning_rate, save_dir, run_ID, device, ...)
        self.options = options
        self.model = model
        self.trajectory_generator = trajectory_generator
        lr = self.options.learning_rate
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

        # Running per-step history, appended to by train().
        self.loss = []
        self.err = []

        # Set up checkpoints
        self.ckpt_dir = os.path.join(options.save_dir, options.run_ID)
        ckpt_path = os.path.join(self.ckpt_dir, "most_recent_model.pth")
        pc_path = os.path.join(self.ckpt_dir, "place_cell_centers.npy")
        if restore and os.path.isdir(self.ckpt_dir) and os.path.isfile(ckpt_path):
            # NOTE(review): torch.load without map_location assumes the
            # checkpoint's device is compatible with the current one — confirm
            # when restoring a GPU checkpoint on CPU.
            self.model.load_state_dict(torch.load(ckpt_path))
            print("Restored trained model from {}".format(ckpt_path))
            # Load place cell centers (the trained weights are only meaningful
            # together with the centers they were trained against).
            if os.path.isfile(pc_path):
                us = np.load(pc_path)
                self.model.place_cells.us = torch.tensor(us).to(options.device)
                print("Restored place cell centers from {}".format(pc_path))
            else:
                print("Warning: place_cell_centers.npy not found! Model may not work correctly.")
        else:
            if not os.path.isdir(self.ckpt_dir):
                os.makedirs(self.ckpt_dir, exist_ok=True)
            print("Initializing new model from scratch.")
            print("Saving to: {}".format(self.ckpt_dir))

    def train_step(self, inputs, pc_outputs, pos):
        """
        Train on one batch of trajectories.

        Args:
            inputs: Batch of 2d velocity inputs with shape [batch_size, sequence_length, 2].
            pc_outputs: Ground truth place cell activations with shape
                [batch_size, sequence_length, Np].
            pos: Ground truth 2d position with shape [batch_size, sequence_length, 2].

        Returns:
            loss: Avg. loss for this training batch.
            err: Avg. decoded position error (meters; scaled to cm by callers).
        """
        self.model.zero_grad()

        loss, err = self.model.compute_loss(inputs, pc_outputs, pos)

        loss.backward()
        self.optimizer.step()

        return loss.item(), err.item()

    def train(self, n_epochs: int = 1000, n_steps=10, save=True):
        """
        Train model on simulated trajectories.

        Args:
            n_epochs: Number of epochs (checkpoints are written every 10th
                epoch and after the first).
            n_steps: Number of training steps per epoch.
            save: If true, save a checkpoint after each epoch.
        """

        # Construct generator
        gen = self.trajectory_generator.get_generator()

        for epoch_idx in range(n_epochs):
            tbar = tqdm(range(n_steps), leave=False)
            for step_idx in tbar:
                inputs, pc_outputs, pos = next(gen)
                loss, err = self.train_step(inputs, pc_outputs, pos)
                self.loss.append(loss)
                self.err.append(err)

                # Log error rate to progress bar (err is in meters, shown as cm)
                tbar.set_description('Error = ' + str(int(100*err)) + 'cm')

            if save and ((epoch_idx + 1) % 10 == 0 or epoch_idx == 0):
                # Save checkpoint
                # ckpt_path = os.path.join(
                #     self.ckpt_dir, "epoch_{}.pth".format(epoch_idx)
                # )
                # torch.save(self.model.state_dict(), ckpt_path)
                torch.save(
                    self.model.state_dict(),
                    os.path.join(self.ckpt_dir, "most_recent_model.pth"),
                )
                # Save place cell centers alongside the weights
                np.save(
                    os.path.join(self.ckpt_dir, "place_cell_centers.npy"),
                    self.model.place_cells.us.cpu().numpy(),
                )

                # Save a picture of rate maps
                save_ratemaps(
                    self.model,
                    self.trajectory_generator,
                    self.options,
                    step=epoch_idx + 1,
                )

            print(
                "Epoch: {}/{}. Date: {}. Loss: {}. Err: {}cm".format(
                    epoch_idx + 1,
                    n_epochs,
                    str(datetime.datetime.now())[:-7],
                    np.round(loss, 2),
                    np.round(100 * err, 2),
                )
            )
|
trajectory_generator.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class TrajectoryGenerator(object):
|
| 8 |
+
def __init__(self, options, place_cells):
    # options: experiment configuration (sequence_length, periodic, ...)
    self.options = options
    # PlaceCells instance used to turn positions into place-cell targets
    self.place_cells = place_cells
|
| 11 |
+
|
| 12 |
+
def plot_trajectory(self, traj, box_width, box_height, idx=0, step=2):
    """
    Visualize one trajectory from traj dict.

    Displays (via plt.show) the path, per-frame head-direction arrows, and
    the box boundaries; purely a side-effect plot, returns nothing.

    Args:
        traj: dictionary containing trajectory info (keys used here:
            target_x/y/hd and init_x/y/hd)
        box_width, box_height: dimensions of the environment
        idx: which trajectory to plot from the batch (default: 0)
        step: plot an arrow every 'step' frames
    """
    # Extract trajectory for one rat
    x = traj["target_x"][idx]  # shape (samples,)
    y = traj["target_y"][idx]
    hd = traj["target_hd"][idx]  # head directions in radians

    # Also add starting point
    x0 = traj["init_x"][idx, 0]
    y0 = traj["init_y"][idx, 0]
    hd0 = traj["init_hd"][idx, 0]

    x = np.concatenate([[x0], x])
    y = np.concatenate([[y0], y])
    hd = np.concatenate([[hd0], hd])

    # Plot trajectory
    plt.figure(figsize=(6, 6))
    plt.plot(x, y, "-o", markersize=2, label="trajectory")

    # Add arrows for head direction (every `step` frames, 0.1 m long)
    for t in range(0, len(x), step):
        dx = 0.1 * np.cos(hd[t])
        dy = 0.1 * np.sin(hd[t])
        plt.arrow(
            x[t], y[t], dx, dy, head_width=0.05, head_length=0.08, fc="r", ec="r"
        )

    # Draw box boundaries
    plt.axhline(y=-box_height / 2, color="k")
    plt.axhline(y=box_height / 2, color="k")
    plt.axvline(x=-box_width / 2, color="k")
    plt.axvline(x=box_width / 2, color="k")

    plt.xlim([-box_width / 2 - 0.2, box_width / 2 + 0.2])
    plt.ylim([-box_height / 2 - 0.2, box_height / 2 + 0.2])
    plt.gca().set_aspect("equal", adjustable="box")
    plt.xlabel("x position (m)")
    plt.ylabel("y position (m)")
    plt.title(f"Trajectory {idx}")
    plt.legend()
    plt.show()
|
| 62 |
+
|
| 63 |
+
def avoid_wall(self, position, hd, box_width, box_height):
    """
    Determine which agents are about to hit a wall and the turn to avoid it.

    Args:
        position: array [batch_size, 2] of x, y coordinates (meters).
        hd: array [batch_size] of head directions (radians).
        box_width, box_height: environment dimensions (meters).

    Returns:
        is_near_wall: boolean array — True where an agent is inside the
            border region AND heading toward its nearest wall.
        turn_angle: corrective turn (radians) for near-wall agents; zero
            elsewhere.
    """
    xs = position[:, 0]
    ys = position[:, 1]
    # Signed distances to the four walls (east, north, west, south).
    wall_dists = [
        box_width / 2 - xs,
        box_height / 2 - ys,
        box_width / 2 + xs,
        box_height / 2 + ys,
    ]
    d_wall = np.min(wall_dists, axis=0)
    # Outward direction of each wall: 0, pi/2, pi, 3pi/2.
    wall_angles = np.arange(4) * np.pi / 2
    nearest_angle = wall_angles[np.argmin(wall_dists, axis=0)]
    # Angle between heading and the nearest wall, wrapped into [-pi, pi).
    heading = np.mod(hd, 2 * np.pi)
    a_wall = np.mod(heading - nearest_angle + np.pi, 2 * np.pi) - np.pi

    is_near_wall = (d_wall < self.border_region) * (np.abs(a_wall) < np.pi / 2)
    turn_angle = np.zeros_like(hd)
    # Turn just enough to run parallel to the wall.
    turn_angle[is_near_wall] = np.sign(a_wall[is_near_wall]) * (
        np.pi / 2 - np.abs(a_wall[is_near_wall])
    )

    return is_near_wall, turn_angle
|
| 89 |
+
|
| 90 |
+
def generate_trajectory(self, box_width, box_height, batch_size):
    """Generate a random walk in a rectangular box.

    Simulates batch_size independent agents for sequence_length + 2
    time steps. Motion follows a Rayleigh-distributed speed and a
    Gaussian rotational velocity; near walls (non-periodic mode) the
    agent slows to 25% speed and turns parallel to the wall.

    Args:
        box_width, box_height: arena dimensions (meters).
        batch_size: number of independent trajectories.

    Returns:
        Dict of numpy arrays — inputs ("init_hd", "init_x", "init_y",
        "ego_v", "phi_x", "phi_y") and targets ("target_hd",
        "target_x", "target_y"), each with leading dim batch_size.
    """
    samples = self.options.sequence_length
    dt = 0.02  # time step increment (seconds)
    # NOTE(review): constants presumably follow rat-motion statistics
    # from Banino et al. / Sorscher et al. — confirm against the paper.
    sigma = 5.76 * 2  # stdev rotation velocity (rads/sec)
    b = 0.13 * 2 * np.pi  # forward velocity rayleigh dist scale (m/sec)
    mu = 0  # turn angle bias
    # Side effect: avoid_wall() reads this attribute.
    self.border_region = 0.03  # meters

    # Initialize variables
    position = np.zeros([batch_size, samples + 2, 2])
    head_dir = np.zeros([batch_size, samples + 2])
    position[:, 0, 0] = np.random.uniform(-box_width / 2, box_width / 2, batch_size)
    position[:, 0, 1] = np.random.uniform(
        -box_height / 2, box_height / 2, batch_size
    )
    head_dir[:, 0] = np.random.uniform(0, 2 * np.pi, batch_size)
    velocity = np.zeros([batch_size, samples + 2])

    # Generate sequence of random boosts and turns
    random_turn = np.random.normal(mu, sigma, [batch_size, samples + 1])
    random_vel = np.random.rayleigh(b, [batch_size, samples + 1])
    # NOTE(review): this draw is dead — `v` is overwritten on the first
    # loop iteration below; kept only to preserve the RNG stream.
    v = np.abs(np.random.normal(0, b * np.pi / 2, batch_size))

    for t in range(samples + 1):
        # Update velocity
        v = random_vel[:, t]
        turn_angle = np.zeros(batch_size)

        if not self.options.periodic:
            # If in border region, turn and slow down
            is_near_wall, turn_angle = self.avoid_wall(
                position[:, t], head_dir[:, t], box_width, box_height
            )
            v[is_near_wall] *= 0.25

        # Update turn angle (wall-avoidance turn plus random rotation)
        turn_angle += dt * random_turn[:, t]

        # Take a step along the current heading
        velocity[:, t] = v * dt
        update = velocity[:, t, None] * np.stack(
            [np.cos(head_dir[:, t]), np.sin(head_dir[:, t])], axis=-1
        )
        position[:, t + 1] = position[:, t] + update

        # Rotate head direction
        head_dir[:, t + 1] = head_dir[:, t] + turn_angle

    # Periodic boundaries: wrap positions back into the box
    if self.options.periodic:
        position[:, :, 0] = (
            np.mod(position[:, :, 0] + box_width / 2, box_width) - box_width / 2
        )
        position[:, :, 1] = (
            np.mod(position[:, :, 1] + box_height / 2, box_height) - box_height / 2
        )

    head_dir = np.mod(head_dir + np.pi, 2 * np.pi) - np.pi  # Periodic variable

    traj = {}
    # Input variables (step 1 is the "initial" state fed to the model)
    traj["init_hd"] = head_dir[:, 0, None]
    traj["init_x"] = position[:, 1, 0, None]
    traj["init_y"] = position[:, 1, 1, None]

    traj["ego_v"] = velocity[:, 1:-1]
    ang_v = np.diff(head_dir, axis=-1)
    traj["phi_x"], traj["phi_y"] = np.cos(ang_v)[:, :-1], np.sin(ang_v)[:, :-1]

    # Target variables (positions/headings the model must predict)
    traj["target_hd"] = head_dir[:, 1:-1]
    traj["target_x"] = position[:, 2:, 0]
    traj["target_y"] = position[:, 2:, 1]

    # for i in range(5):
    # self.plot_trajectory(traj, box_width, box_height, i)
    # raise Exception("dog")

    return traj
|
| 170 |
+
|
| 171 |
+
def get_generator(self, batch_size=None, box_width=None, box_height=None):
    """
    Return an infinite generator over training batches.

    Each yielded element is ((v, init_actv), place_outputs, pos):
    v — [seq_len, batch, 2] allocentric velocity; init_actv — initial
    place-cell activation; place_outputs — place-cell targets;
    pos — [seq_len, batch, 2] true positions. Tensors are moved to
    self.options.device.
    """
    batch_size = batch_size or self.options.batch_size
    box_width = box_width or self.options.box_width
    box_height = box_height or self.options.box_height

    device = self.options.device

    while True:
        traj = self.generate_trajectory(box_width, box_height, batch_size)

        # Decompose egocentric speed into allocentric (x, y) velocity.
        vel_xy = np.stack(
            [
                traj["ego_v"] * np.cos(traj["target_hd"]),
                traj["ego_v"] * np.sin(traj["target_hd"]),
            ],
            axis=-1,
        )
        v = torch.tensor(vel_xy, dtype=torch.float32).transpose(0, 1)

        pos_xy = np.stack([traj["target_x"], traj["target_y"]], axis=-1)
        pos = torch.tensor(pos_xy, dtype=torch.float32).transpose(0, 1)
        # Put on GPU if GPU is available
        pos = pos.to(device)
        place_outputs = self.place_cells.get_activation(pos)

        init_xy = np.stack([traj["init_x"], traj["init_y"]], axis=-1)
        init_pos = torch.tensor(init_xy, dtype=torch.float32).to(device)
        init_actv = self.place_cells.get_activation(init_pos).squeeze()

        v = v.to(device)
        yield ((v, init_actv), place_outputs, pos)
|
| 209 |
+
|
| 210 |
+
def get_test_batch(self, batch_size=None, box_width=None, box_height=None):
    """
    Build a single batch of sample trajectories for evaluating
    performance.

    Returns ((v, init_actv), pos, place_outputs) — note the ordering
    differs from get_generator, which yields (inputs, place_outputs,
    pos). Tensors are moved to self.options.device.
    """
    batch_size = batch_size or self.options.batch_size
    box_width = box_width or self.options.box_width
    box_height = box_height or self.options.box_height

    device = self.options.device
    traj = self.generate_trajectory(box_width, box_height, batch_size)

    # Decompose egocentric speed into allocentric (x, y) velocity.
    vel_xy = np.stack(
        [
            traj["ego_v"] * np.cos(traj["target_hd"]),
            traj["ego_v"] * np.sin(traj["target_hd"]),
        ],
        axis=-1,
    )
    v = torch.tensor(vel_xy, dtype=torch.float32).transpose(0, 1)

    pos_xy = np.stack([traj["target_x"], traj["target_y"]], axis=-1)
    pos = torch.tensor(pos_xy, dtype=torch.float32).transpose(0, 1).to(device)
    place_outputs = self.place_cells.get_activation(pos)

    init_xy = np.stack([traj["init_x"], traj["init_y"]], axis=-1)
    init_pos = torch.tensor(init_xy, dtype=torch.float32).to(device)
    init_actv = self.place_cells.get_activation(init_pos).squeeze()

    v = v.to(device)
    return ((v, init_actv), pos, place_outputs)
|
utils.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def generate_run_ID(options):
    """
    Build a unique, filesystem-friendly run identifier from the most
    relevant hyperparameters. Remaining parameters can be found in the
    params.npy file.
    """
    fields = (
        'steps', str(options.sequence_length),
        'batch', str(options.batch_size),
        options.RNN_type,
        str(options.Ng),
        options.activation,
        'rf', str(options.place_cell_rf),
        'DoG', str(options.DoG),
        'periodic', str(options.periodic),
        'lr', str(options.learning_rate),
        'weight_decay', str(options.weight_decay),
    )
    # Dots would clutter directory names, so strip them (0.12 -> 012).
    return '_'.join(fields).replace('.', '')
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_2d_sort(x1, x2):
    """
    Reshape x1 and x2 into square arrays and sort them such that x1
    increases downward and x2 increases rightward.

    Returns:
        The flattened ordering: an index array of length n*n, where
        n = round(sqrt(len(x1))).
    """
    n = int(np.round(np.sqrt(len(x1))))
    # Rows are first ordered by x1 ...
    order = np.argsort(x1).reshape(n, n)
    # ... then each row is independently re-ordered by x2.
    for row in range(n):
        row_vals = x2[order.ravel()].reshape(n, n)[row]
        order[row] = order[row][np.argsort(row_vals)]
    return order.ravel()
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def dft(N, real=False, scale='sqrtn'):
    """
    Return an N x N discrete Fourier transform matrix.

    Args:
        N: transform size.
        real: if False, return scipy's complex DFT matrix. If True,
            return a real orthonormal cosine/sine basis instead.
        scale: scaling convention forwarded to scipy.linalg.dft
            (only used when real=False).
    """
    if not real:
        # Imported here: this module only imports numpy at top level,
        # so the original `scipy.linalg.dft` call raised NameError.
        import scipy.linalg
        return scipy.linalg.dft(N, scale)
    # Real basis: cosine columns for frequencies 0..N//2 followed by
    # sine columns for frequencies 1..(N-1)//2 in reverse order.
    cosines = np.cos(2 * np.pi * np.arange(N // 2 + 1)[None, :] / N * np.arange(N)[:, None])
    sines = np.sin(2 * np.pi * np.arange(1, (N - 1) // 2 + 1)[None, :] / N * np.arange(N)[:, None])
    if N % 2 == 0:
        # The Nyquist column has norm sqrt(N); pre-scale it so the
        # shared 1/sqrt(N/2) factor below yields unit norm.
        cosines[:, -1] /= np.sqrt(2)
    F = np.concatenate((cosines, sines[:, ::-1]), 1)
    # Normalize every column to unit length (orthonormal basis).
    F[:, 0] /= np.sqrt(N)
    F[:, 1:] /= np.sqrt(N / 2)
    return F
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def skaggs_power(Jsort):
    """
    Fraction of the spectral power of the (2-D-sorted) weight matrix
    Jsort carried by the lowest nonzero diagonal spatial frequencies.

    Args:
        Jsort: [N, N] matrix with N a perfect square, viewable as an
            n x n grid of n x n interaction blocks (n = sqrt(N)).

    Returns:
        Scalar in [0, 1].

    NOTE(review): the original referenced an undefined global `N`
    (NameError when called); N is now derived from Jsort's shape.
    """
    N = Jsort.shape[0]
    F = dft(int(np.sqrt(N)), real=True)
    # Separable 2-D transform basis built from the 1-D real DFT.
    F2d = F[:, None, :, None] * F[None, :, None, :]

    F2d_unroll = np.reshape(F2d, (N, N))

    F2d_inv = F2d_unroll.conj().T
    Jtilde = F2d_inv.dot(Jsort).dot(F2d_unroll)

    # Power at the (1,1) and (-1,-1) frequency pair over total power.
    return (Jtilde[1, 1]**2 + Jtilde[-1, -1]**2) / (Jtilde**2).sum()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def skaggs_power_2(Jsort):
    """
    Fraction of spectral power at the lowest nonzero spatial
    frequencies of the translation-averaged interaction profile.

    Args:
        Jsort: [n*n, n*n] matrix whose rows/columns index an n x n
            spatial grid.

    Returns:
        Scalar in [0, 1].

    NOTE(review): the original referenced an undefined global `n`
    (NameError when called); n is now derived from Jsort's shape.
    """
    n = int(np.round(np.sqrt(Jsort.shape[0])))
    J_square = np.reshape(Jsort, (n, n, n, n))
    Jmean = np.zeros([n, n])
    # Average the interaction profile over all source locations by
    # shifting each block so its own location sits at the origin.
    for i in range(n):
        for j in range(n):
            Jmean += np.roll(np.roll(J_square[i, j], -i, axis=0), -j, axis=1)

    # Jmean[0,0] = np.max(Jmean[1:,1:])
    # Center the profile before transforming.
    Jmean = np.roll(np.roll(Jmean, n // 2, axis=0), n // 2, axis=1)
    Jtilde = np.real(np.fft.fft2(Jmean))

    Jtilde[0, 0] = 0  # discard the DC component
    # Power at the six lowest nonzero frequencies over total power.
    sk_power = Jtilde[1, 1]**2 + Jtilde[0, 1]**2 + Jtilde[1, 0]**2
    sk_power += Jtilde[-1, -1]**2 + Jtilde[0, -1]**2 + Jtilde[-1, 0]**2
    sk_power /= (Jtilde**2).sum()

    return sk_power
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def calc_err():
    """
    Mean decoding error: Euclidean distance between the true positions
    and the positions of the most active predicted place cells,
    averaged over a batch.

    NOTE(review): relies on module-level globals `gen`, `model`, and
    `place_cells` that are not defined in this file, and on `tf`
    (TensorFlow), which is never imported here — this looks like
    notebook code pasted in and raises NameError if called as-is.
    """
    inputs, _, pos = next(gen)
    pred = model(inputs)
    # Decode each prediction to the location of its strongest place cell.
    pred_pos = place_cells.get_nearest_cell_pos(pred)
    return tf.reduce_mean(tf.sqrt(tf.reduce_sum((pos - pred_pos)**2, axis=-1)))
|
| 95 |
+
|
| 96 |
+
from visualize import compute_ratemaps, plot_ratemaps
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def compute_variance(res, n_avg):
    """
    Per-spatial-bin mismatch between single-visit grid-cell population
    vectors and the bin's mean rate map.

    Args:
        res: spatial resolution (number of bins per side).
        n_avg: number of batches to average over.

    Returns:
        variance: [res, res] array of the mean normalized mismatch per
            bin; bins that are never visited remain 0.

    NOTE(review): depends on module-level globals `model`,
    `data_manager`, and `tqdm` that are not defined/imported in this
    file. Also indexes `options` as a dict here, while other functions
    in this project use attribute access (options.box_width) — confirm
    the intended options type before use.
    """

    activations, rate_map, g, pos = compute_ratemaps(model, data_manager, options, res=res, n_avg=n_avg)

    counts = np.zeros([res,res])
    variance = np.zeros([res,res])

    # Map continuous positions into [0, res) bin coordinates.
    x_all = (pos[:,0] + options['box_width']/2) / options['box_width'] * res
    y_all = (pos[:,1] + options['box_height']/2) / options['box_height'] * res
    for i in tqdm(range(len(g))):
        x = int(x_all[i])
        y = int(y_all[i])
        if x >=0 and x < res and y >=0 and y < res:
            counts[x, y] += 1
            # Distance between this visit's population vector and the
            # bin's mean activation, normalized by both magnitudes.
            variance[x, y] += np.linalg.norm(g[i] - activations[:, x, y]) / np.linalg.norm(g[i]) / np.linalg.norm(activations[:,x,y])

    # Convert accumulated sums into per-bin means.
    for x in range(res):
        for y in range(res):
            if counts[x, y] > 0:
                variance[x, y] /= counts[x, y]

    return variance
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def load_trained_weights(model, trainer, weight_dir):
    """Load model weights stored as a .npy file (convenient for github).

    Args:
        model: object exposing a set_weights() method.
        trainer: unused; kept for call-site compatibility.
        weight_dir: path to the .npy weights file.
    """
    # Train for a single step to initialize weights
    # trainer.train(n_epochs=1, n_steps=1, save=False)

    loaded = np.load(weight_dir, allow_pickle=True)
    model.set_weights(loaded)
    print('Loaded trained weights.')
|
| 133 |
+
|
visualize.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
import numpy as np
|
| 3 |
+
from matplotlib import pyplot as plt
|
| 4 |
+
|
| 5 |
+
import scipy
|
| 6 |
+
import scipy.stats
|
| 7 |
+
from imageio import imsave
|
| 8 |
+
import cv2
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def concat_images(images, image_width, spacer_size):
    """Concatenate RGBA images horizontally with white spacers between them.

    Args:
        images: sequence of [image_width, W, 4] uint8 arrays.
        image_width: image height in pixels (also used as spacer height;
            images are assumed square).
        spacer_size: spacer width in pixels.
    """
    spacer = np.full([image_width, spacer_size, 4], 255, dtype=np.uint8)
    interleaved = []
    for idx, img in enumerate(images):
        if idx:
            # White gap before every image except the first.
            interleaved.append(spacer)
        interleaved.append(img)
    return np.hstack(interleaved)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def concat_images_in_rows(images, row_size, image_width, spacer_size=4):
    """Tile images into a grid of `row_size` rows separated by white spacers.

    Args:
        images: sequence of [image_width, image_width, 4] uint8 arrays.
        row_size: number of rows in the grid.
        image_width: image side length in pixels.
        spacer_size: spacer thickness in pixels.
    """
    per_row = len(images) // row_size
    total_width = image_width * per_row + (per_row - 1) * spacer_size
    h_spacer = np.full([spacer_size, total_width, 4], 255, dtype=np.uint8)

    rows = []
    for r in range(row_size):
        if r:
            # Horizontal white gap before every row except the first.
            rows.append(h_spacer)
        chunk = images[r * per_row:(r + 1) * per_row]
        rows.append(concat_images(chunk, image_width, spacer_size))

    return np.vstack(rows)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def convert_to_colormap(im, cmap):
    """Apply a colormap callable to `im` and return uint8 RGBA values."""
    return np.uint8(cmap(im) * 255)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def rgb(im, cmap='jet', smooth=True):
    """Normalize a 2-D rate map to [0, 1], optionally blur it, and
    render it through a matplotlib colormap as uint8 RGBA.

    NOTE(review): np.seterr mutates numpy's *global* error state for
    the rest of the process, not just this call — np.errstate would be
    the scoped alternative. A constant image makes the normalization
    below divide by zero (the reason for the seterr call); the
    resulting NaNs are passed straight through to the colormap.
    """
    cmap = plt.cm.get_cmap(cmap)
    np.seterr(invalid='ignore') # ignore divide by zero err
    # Min-max normalize to [0, 1] so the colormap spans the full range.
    im = (im - np.min(im)) / (np.max(im) - np.min(im))
    if smooth:
        # Light Gaussian blur; sigmaY=0 makes OpenCV reuse sigmaX.
        im = cv2.GaussianBlur(im, (3,3), sigmaX=1, sigmaY=0)
    im = cmap(im)
    im = np.uint8(im * 255)
    return im
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def plot_ratemaps(activations, n_plots, cmap='jet', smooth=True, width=16):
    """Render the first n_plots rate maps as one tiled RGBA image.

    Args:
        activations: [Ng, res, res] spatial rate maps.
        n_plots: number of maps to render (n_plots // width rows).
        cmap: matplotlib colormap name.
        smooth: whether to blur each map before rendering.
        width: number of maps per row.
    """
    tiles = []
    for ratemap in activations[:n_plots]:
        tiles.append(rgb(ratemap, cmap, smooth))
    return concat_images_in_rows(tiles, n_plots // width, activations.shape[-1])
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def compute_ratemaps(model, trajectory_generator, options, res=20, n_avg=None, Ng=512, idxs=None, return_raw=False):
    '''Compute spatial firing fields

    Args:
        model: The RNN model
        trajectory_generator: Generator for test trajectories
        options: Training options
        res: Resolution of the rate map grid
        n_avg: Number of batches to average over
        Ng: Number of grid cells to analyze
        idxs: Indices of specific grid cells to analyze
        return_raw: If True, also return raw activations (g) and positions (pos).
            Warning: This uses significant memory for large batch_size/n_avg.
            If False, returns None for g and pos to save memory.

    Returns:
        activations: Spatial firing fields [Ng, res, res]
        rate_map: Flattened rate maps [Ng, res*res]
        g: Raw activations (None if return_raw=False)
        pos: Raw positions (None if return_raw=False)
    '''

    if not n_avg:
        # Default: enough batches for ~1000 total time steps.
        n_avg = 1000 // options.sequence_length

    # NOTE(review): `not np.any(idxs)` also triggers when idxs is given
    # but all-zero (e.g. idxs=[0]), silently replacing it — confirm
    # callers never pass index 0 alone.
    if not np.any(idxs):
        idxs = np.arange(Ng)
    idxs = idxs[:Ng]

    # Only allocate large arrays if return_raw is True
    if return_raw:
        g = np.zeros([n_avg, options.batch_size * options.sequence_length, Ng])
        pos = np.zeros([n_avg, options.batch_size * options.sequence_length, 2])
    else:
        g = None
        pos = None

    activations = np.zeros([Ng, res, res])
    counts = np.zeros([res, res])

    for index in range(n_avg):
        inputs, pos_batch, _ = trajectory_generator.get_test_batch()
        g_batch = model.g(inputs).detach().cpu().numpy()

        # Flatten (seq, batch) into one axis of visits.
        pos_batch = np.reshape(pos_batch.cpu(), [-1, 2])
        g_batch = g_batch[:,:,idxs].reshape(-1, Ng)

        if return_raw:
            g[index] = g_batch
            pos[index] = pos_batch

        # Map continuous positions into [0, res) bin coordinates.
        x_batch = (pos_batch[:,0] + options.box_width/2) / (options.box_width) * res
        y_batch = (pos_batch[:,1] + options.box_height/2) / (options.box_height) * res

        # Accumulate activations and visit counts per spatial bin.
        for i in range(options.batch_size*options.sequence_length):
            x = x_batch[i]
            y = y_batch[i]
            if x >=0 and x < res and y >=0 and y < res:
                counts[int(x), int(y)] += 1
                activations[:, int(x), int(y)] += g_batch[i, :]

    # Convert accumulated sums into per-bin means (unvisited bins stay 0).
    for x in range(res):
        for y in range(res):
            if counts[x, y] > 0:
                activations[:, x, y] /= counts[x, y]

    if return_raw:
        g = g.reshape([-1, Ng])
        pos = pos.reshape([-1, 2])

    # # scipy binned_statistic_2d is slightly slower
    # activations = scipy.stats.binned_statistic_2d(pos[:,0], pos[:,1], g.T, bins=res)[0]
    rate_map = activations.reshape(Ng, -1)

    return activations, rate_map, g, pos
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def save_ratemaps(model, trajectory_generator, options, step, res=20, n_avg=None):
    """Compute rate maps for the current model and save a tiled PNG
    named after the training step under save_dir/run_ID.

    Args:
        model: the RNN model.
        trajectory_generator: generator for test trajectories.
        options: training options (save_dir, run_ID, sequence_length).
        step: training step used as the output file name.
        res: rate-map grid resolution.
        n_avg: batches to average over (default ~1000 total steps).
    """
    if not n_avg:
        n_avg = 1000 // options.sequence_length
    activations, _, _, _ = compute_ratemaps(
        model, trajectory_generator, options, res=res, n_avg=n_avg
    )
    fig = plot_ratemaps(activations, n_plots=len(activations))
    out_path = options.save_dir + "/" + options.run_ID + "/" + str(step) + ".png"
    imsave(out_path, fig)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def save_autocorr(sess, model, save_name, trajectory_generator, step, flags):
    """Score grid-ness of bottleneck activations over sampled
    trajectories and save autocorrelation plots to a PDF.

    NOTE(review): TensorFlow-1 style code (sess.run / feed_dict) that
    also references `scores` and `utils`, neither of which is imported
    in this file — calling this as-is raises NameError. Presumably
    carried over from the DeepMind grid-cells codebase; verify before
    use with this repository's PyTorch model.
    """
    # Annulus mask parameters for the grid scorer: fixed inner radius,
    # sweep of outer radii.
    starts = [0.2] * 10
    ends = np.linspace(0.4, 1.0, num=10)
    coord_range=((-1.1, 1.1), (-1.1, 1.1))
    masks_parameters = zip(starts, ends.tolist())
    latest_epoch_scorer = scores.GridScorer(20, coord_range, masks_parameters)

    res = dict()
    index_size = 100
    # Collect positions and bottleneck activations over many minibatches.
    for _ in range(index_size):
        feed_dict = trajectory_generator.feed_dict(flags.box_width, flags.box_height)
        mb_res = sess.run({
            'pos_xy': model.target_pos,
            'bottleneck': model.g,
        }, feed_dict=feed_dict)
        res = utils.concat_dict(res, mb_res)

    filename = save_name + '/autocorrs_' + str(step) + '.pdf'
    imdir = flags.save_dir + '/'
    out = utils.get_scores_and_plot(
        latest_epoch_scorer, res['pos_xy'], res['bottleneck'],
        imdir, filename)
|